[Feature] Add Multi-Token Prediction (MTP) module implementation #1570
HAOCHENYE wants to merge 6 commits into gh/HAOCHENYE/17/base from
Conversation
ghstack-source-id: 1b45af1 Pull-Request: InternLM#1570
ghstack-source-id: 2d84ad6 Pull-Request: InternLM#1570
ghstack-source-id: db84856 Pull-Request: InternLM#1570
Kirrito-k423
left a comment
Thanks for this well-structured MTP implementation! The overall architecture is clean with good separation of concerns (MTPConfig, MTPLayer, MTPBlock). I have a few important issues and suggestions below.
# Step 3: Pass through the standard decoder layer
# This includes attention, MLP, and their respective normalizations
# TODO: TMP hardcode here.
🟠 Important: This TODO comment indicates unfinished logic. What is hardcoded here? If this is a temporary workaround, please clarify what needs to be done before merging, or create a follow-up issue.
# xtuner: mtp_block.layers.{idx}.enorm -> HF: mtp.pre_fc_norm_embedding
# xtuner: mtp_block.layers.{idx}.hnorm -> HF: mtp.pre_fc_norm_hidden
# xtuner: mtp_block.layers.{idx}.final_layernorm -> HF: mtp.norm
# Note: Currently assuming single MTP layer (idx=0), may need adjustment for multiple layers
🟠 Important: The comment indicates this assumes single MTP layer, but MTPConfig.num_layers supports multiple layers. For num_layers > 1, the key mappings will be incorrect. Please either add validation to reject num_layers > 1 if not supported, or implement correct multi-layer key mapping.
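A minimal sketch of the validation option the review suggests, assuming a hypothetical helper `mtp_key_mapping` that builds the xtuner-to-HF key table (the function name and dict-based return are illustrative, not the PR's actual API):

```python
def mtp_key_mapping(num_layers: int) -> dict[str, str]:
    # Hypothetical guard: the HF-side names below have no layer index,
    # so the mapping is only well-defined for a single MTP layer.
    if num_layers > 1:
        raise NotImplementedError(
            "HF key mapping currently assumes a single MTP layer (idx=0); "
            "extend the mapping before enabling num_layers > 1."
        )
    idx = 0
    # Mappings taken from the comment quoted above.
    return {
        f"mtp_block.layers.{idx}.enorm": "mtp.pre_fc_norm_embedding",
        f"mtp_block.layers.{idx}.hnorm": "mtp.pre_fc_norm_hidden",
        f"mtp_block.layers.{idx}.final_layernorm": "mtp.norm",
    }
```

Failing fast here keeps a silently wrong checkpoint conversion from reaching users until multi-layer mapping is actually implemented.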
attention mask, etc.

Returns:
    list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: List of 3-tuples
🟡 Suggestion: The return type annotation says list[tuple[...]], but the append logic could be clearer: at line 105, mtp_outputs.append(current_hidden_states) appends a tuple, yet the name current_hidden_states suggests a single tensor. Consider renaming the variable so the element type is obvious at the append site.
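The renaming pattern suggested above might look like this sketch (the loop body and variable names are illustrative stand-ins, not the PR's code; plain floats stand in for torch.Tensor so the example is self-contained):

```python
# Bind the per-depth result to an explicitly named 3-tuple before appending,
# so the element type of mtp_outputs matches the annotated return type.
mtp_outputs: list[tuple[float, float, float]] = []
for depth in range(2):  # stand-in for iterating over MTP depths
    hidden_states = 0.0  # would be a torch.Tensor in the real code
    logits = 0.0
    loss = 0.0
    depth_output = (hidden_states, logits, loss)  # name says "tuple", not "tensor"
    mtp_outputs.append(depth_output)
```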
)

# Compute MTP losses for each depth
mtp_losses = torch.tensor(0.0, device=DEVICE)
🟡 Suggestion: Creating a tensor with torch.tensor(0.0, device=DEVICE) inside a loop can be inefficient. Consider initializing outside the loop with torch.zeros(1, device=DEVICE).
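A sketch of the accumulation pattern the review suggests, with the accumulator created once outside the loop (the function name and the `DEVICE` stand-in are hypothetical, not the PR's actual code):

```python
import torch

DEVICE = torch.device("cpu")  # hypothetical stand-in for the module's DEVICE

def accumulate_mtp_losses(depth_losses: list[torch.Tensor]) -> torch.Tensor:
    # Create the scalar accumulator once, before the per-depth loop,
    # instead of calling torch.tensor(0.0, device=DEVICE) on each iteration.
    mtp_losses = torch.zeros((), device=DEVICE)
    for loss in depth_losses:
        mtp_losses = mtp_losses + loss
    return mtp_losses
```

Using a 0-dim tensor (`torch.zeros(())`) keeps the result a scalar, matching the shape of `torch.tensor(0.0, ...)` in the original code.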
ghstack-source-id: 8c98b3a Pull-Request: InternLM#1570
ghstack-source-id: e63ad27 Pull-Request: InternLM#1570
Stack from ghstack (oldest at bottom):