[megatron] support megatron MTP #6496
Conversation
Summary of Changes

Hello @Jintao-Huang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the Megatron framework by introducing support for Multi-Token Prediction (MTP) within the GPT model. The changes involve adding new configurable parameters for MTP layers and loss scaling, alongside a substantial refactoring of the GPT model's forward pass. This refactoring separates input preprocessing and output postprocessing into distinct methods, creating a more modular and maintainable architecture that seamlessly incorporates MTP as an additional training objective.
Code Review
This pull request introduces support for Multi-Token Prediction (MTP) in Megatron, including new command-line arguments and documentation. The core change is a significant refactoring of the GPTModel.forward method to accommodate the MTP logic. My review focuses on the correctness of this refactoring and the consistency of the documentation. I've identified a critical breaking API change in the forward method signature that needs to be addressed. Additionally, there are some language inconsistencies in the documentation files that should be corrected for clarity.
swift/megatron/model/gpt_model.py
Outdated
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    attention_mask: torch.Tensor,
The signature of the forward method has been changed, making attention_mask a required positional argument. This is a breaking API change: any caller that does not pass attention_mask will now raise a TypeError. For example, MultimodalGPTModel.forward in swift/megatron/model/mm_gpt_model.py will break. To maintain backward compatibility, attention_mask should remain an optional keyword argument.
Suggested change:
-    attention_mask: torch.Tensor,
+    attention_mask: torch.Tensor = None,
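
To make the impact concrete, here is a small standalone sketch (not code from this repo) of why dropping the default breaks existing call sites:

def forward_old(input_ids, position_ids, attention_mask=None, decoder_input=None):
    # previous signature: attention_mask is optional
    return 'ok'

def forward_new(input_ids, position_ids, attention_mask, decoder_input=None):
    # new signature: attention_mask is a required positional argument
    return 'ok'

forward_old([1, 2], [0, 1], decoder_input='embeds')   # fine: attention_mask defaults to None
try:
    forward_new([1, 2], [0, 1], decoder_input='embeds')
except TypeError as e:
    print(e)  # missing 1 required positional argument: 'attention_mask'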
- mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None.
- mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1.
The descriptions for the new MTP parameters are in English, but this documentation file is in Chinese. For consistency, please translate these descriptions into Chinese.
Suggested change:
- - mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None.
- - mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1.
+ - mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。
+ - mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。
- mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。
- mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。
The descriptions for the new MTP parameters are in Chinese, but this documentation file is in English. For consistency, please translate these descriptions into English.
Suggested change:
- - mtp_num_layers: 多token预测(MTP)层的数量。MTP将每个位置的预测范围扩展到多个未来token。此MTP实现使用D个顺序模块依次预测D个额外的token。默认为None。
- - mtp_loss_scaling_factor: 多token预测(MTP)损失的缩放因子。我们计算所有深度上MTP损失的平均值,然后乘以该缩放因子得到总体MTP损失,它将作为一个额外的训练目标。默认为0.1。
+ - mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope at each position to multiple future tokens. This MTP implementation uses D sequential modules to sequentially predict D additional tokens. Default is None.
+ - mtp_loss_scaling_factor: Scaling factor of Multi-Token Prediction (MTP) loss. We compute the average of MTP losses across all depths, then multiply it by this scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1.
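
Separately, the semantics of mtp_loss_scaling_factor described above amount to the following (standalone sketch with made-up numbers, not code from Megatron or this PR):

mtp_losses = [1.20, 1.35]        # hypothetical per-depth MTP losses for D = 2 layers
mtp_loss_scaling_factor = 0.1
lm_loss = 2.0                    # hypothetical standard next-token loss

# average across depths, scale, and add as an extra training objective
mtp_loss = mtp_loss_scaling_factor * sum(mtp_losses) / len(mtp_losses)
total_loss = lm_loss + mtp_loss  # 2.0 + 0.1 * 1.275 = 2.1275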
/gemini review
Code Review
This pull request introduces support for Megatron's Multi-Token Prediction (MTP), an additional training objective whose extra prediction heads can also be used for speculative decoding at inference time. The changes include adding new command-line arguments, documentation, and example scripts for MTP. Core logic changes involve modifying the model bridge for weight conversion of MTP layers, particularly for qwen3_next, and patching Megatron's MTP layer to handle packed sequences. The implementation also adds speculative decoding parameters for the sglang inference backend.
My review found a critical issue in the logic for handling packed sequences within the patched MTP layer, which could lead to incorrect behavior. I've provided a code suggestion to fix this. The rest of the changes, including documentation, examples, and argument handling, appear to be correct and well-implemented.
swift/megatron/init.py
Outdated
if packed_seq:
    packed_seq_params = deepcopy(packed_seq_params)
    tensor = packed_seq_params.cu_seqlens_q
    cu_seqlens = torch.concat([tensor.new_zeros(1, ), tensor[1:] - 1, tensor[-1:]])
The logic for adjusting cu_seqlens for packed sequences in the MTP layer appears to be incorrect. The current implementation does not correctly shorten each sequence in the batch by one token for the next-token prediction task. Instead, it seems to shorten only the first sequence and add an erroneous sequence of length 1 at the end, which will likely lead to errors or incorrect behavior when using MTP with packed sequences.
A correct implementation should adjust the cumulative sequence lengths to account for each sequence being shortened by one. This can be achieved by subtracting an increasing sequence [0, 1, 2, ...] from the original cu_seqlens tensor.
Suggested change:
-    cu_seqlens = torch.concat([tensor.new_zeros(1, ), tensor[1:] - 1, tensor[-1:]])
+    cu_seqlens = tensor - torch.arange(len(tensor), device=tensor.device, dtype=tensor.dtype)
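
A quick standalone check of the suggested fix (illustrative values, not code from this PR): with two packed sequences of lengths 5 and 4, cu_seqlens_q is [0, 5, 9]; the corrected expression shortens every sequence by one token, while the original only trims the first and appends a stray length-1 segment.

import torch

cu_seqlens_q = torch.tensor([0, 5, 9], dtype=torch.int32)  # two packed sequences: lengths 5 and 4

# suggested fix: subtract [0, 1, 2, ...] so each sequence loses its last token
fixed = cu_seqlens_q - torch.arange(len(cu_seqlens_q), device=cu_seqlens_q.device, dtype=cu_seqlens_q.dtype)
print(fixed)  # tensor([0, 4, 7], dtype=torch.int32) -> lengths 4 and 3

# original expression: only the first sequence is shortened, and a length-1 segment is appended
wrong = torch.concat([cu_seqlens_q.new_zeros(1), cu_seqlens_q[1:] - 1, cu_seqlens_q[-1:]])
print(wrong)  # tensor([0, 4, 8, 9], dtype=torch.int32) -> lengths 4, 4, 1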
/gemini review
Code Review
This pull request introduces support for Megatron Multi-Token Prediction (MTP) and speculative decoding for vLLM/sglang backends. The changes are comprehensive, spanning documentation, example scripts, and core logic for both training and inference. The implementation appears solid, including new arguments, necessary monkey-patching for Megatron integration, and weight conversion logic. I have a couple of suggestions to enhance code robustness and consistency.
swift/megatron/init.py
Outdated
if packed_seq and not self.config.apply_rope_fusion:
    assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
    rotary_pos_emb = rotary_pos_emb[position_ids[0]]
if self.config.recompute_granularity == 'full' and self.training:
    hidden_states = self._checkpointed_forward(
        partial(
            self._proj_and_transformer_layer,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
        ),
        hidden_states=hidden_states,
        decoder_input=decoder_input,
        attention_mask=attention_mask,
        context=context,
        context_mask=context_mask,
        rotary_pos_emb=rotary_pos_emb,
        rotary_pos_cos=rotary_pos_cos,
        rotary_pos_sin=rotary_pos_sin,
        attention_bias=attention_bias,
        inference_params=inference_params,
    )
else:
    hidden_states = self._proj_and_transformer_layer(
        hidden_states=hidden_states,
        decoder_input=decoder_input,
        attention_mask=attention_mask,
        context=context,
        context_mask=context_mask,
        rotary_pos_emb=rotary_pos_emb,
        rotary_pos_cos=rotary_pos_cos,
        rotary_pos_sin=rotary_pos_sin,
        attention_bias=attention_bias,
        inference_params=inference_params,
        packed_seq_params=packed_seq_params,
        sequence_len_offset=sequence_len_offset,
    )
self.config.apply_rope_fusion = apply_rope_fusion
Modifying self.config.apply_rope_fusion during the forward pass can be risky. If an exception occurs between setting it to False and restoring its original value, the config state will be left inconsistent, which could affect subsequent operations. It's safer to use a try...finally block to ensure the original value is always restored.
try:
    if packed_seq and not self.config.apply_rope_fusion:
        assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
        rotary_pos_emb = rotary_pos_emb[position_ids[0]]
    if self.config.recompute_granularity == 'full' and self.training:
        hidden_states = self._checkpointed_forward(
            partial(
                self._proj_and_transformer_layer,
                packed_seq_params=packed_seq_params,
                sequence_len_offset=sequence_len_offset,
            ),
            hidden_states=hidden_states,
            decoder_input=decoder_input,
            attention_mask=attention_mask,
            context=context,
            context_mask=context_mask,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            attention_bias=attention_bias,
            inference_params=inference_params,
        )
    else:
        hidden_states = self._proj_and_transformer_layer(
            hidden_states=hidden_states,
            decoder_input=decoder_input,
            attention_mask=attention_mask,
            context=context,
            context_mask=context_mask,
            rotary_pos_emb=rotary_pos_emb,
            rotary_pos_cos=rotary_pos_cos,
            rotary_pos_sin=rotary_pos_sin,
            attention_bias=attention_bias,
            inference_params=inference_params,
            packed_seq_params=packed_seq_params,
            sequence_len_offset=sequence_len_offset,
        )
finally:
    self.config.apply_rope_fusion = apply_rope_fusion

self.vllm_limit_mm_per_prompt = json_parse_to_dict(self.vllm_limit_mm_per_prompt)
if self.vllm_speculative_config is not None:
    self.vllm_speculative_config = json_parse_to_dict(self.vllm_speculative_config)
self.vllm_engine_kwargs = json_parse_to_dict(self.vllm_engine_kwargs)
For consistency with how vllm_limit_mm_per_prompt and vllm_speculative_config are handled, it's good practice to also check if vllm_engine_kwargs is not None before parsing it. While json_parse_to_dict might handle None gracefully, this change improves code clarity and robustness.
Suggested change:
-self.vllm_engine_kwargs = json_parse_to_dict(self.vllm_engine_kwargs)
+if self.vllm_engine_kwargs is not None:
+    self.vllm_engine_kwargs = json_parse_to_dict(self.vllm_engine_kwargs)
No description provided.