
[Feature] Add Multi-Token Prediction (MTP) module implementation#1570

Open
HAOCHENYE wants to merge 6 commits into gh/HAOCHENYE/17/base from gh/HAOCHENYE/17/head

Conversation

HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 13, 2026
HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 13, 2026
HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 16, 2026
HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 17, 2026
HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 20, 2026
HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 20, 2026
HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 20, 2026
HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 22, 2026
HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 24, 2026
HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 24, 2026

@Kirrito-k423 left a comment


Thanks for this well-structured MTP implementation! The overall architecture is clean with good separation of concerns (MTPConfig, MTPLayer, MTPBlock). I have a few important issues and suggestions below.


# Step 3: Pass through the standard decoder layer
# This includes attention, MLP, and their respective normalizations
# TODO: TMP hardcode here.

🟠 Important: This TODO comment indicates unfinished logic. What is hardcoded here? If this is a temporary workaround, please clarify what needs to be done before merging, or create a follow-up issue.

# xtuner: mtp_block.layers.{idx}.enorm -> HF: mtp.pre_fc_norm_embedding
# xtuner: mtp_block.layers.{idx}.hnorm -> HF: mtp.pre_fc_norm_hidden
# xtuner: mtp_block.layers.{idx}.final_layernorm -> HF: mtp.norm
# Note: Currently assuming single MTP layer (idx=0), may need adjustment for multiple layers

🟠 Important: The comment indicates this assumes single MTP layer, but MTPConfig.num_layers supports multiple layers. For num_layers > 1, the key mappings will be incorrect. Please either add validation to reject num_layers > 1 if not supported, or implement correct multi-layer key mapping.
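The suggested fix could be sketched as follows. This is a minimal illustration, not the PR's actual code: the function name `mtp_key_mapping` is hypothetical, and the HF key names are taken from the quoted comment above, which only defines them for a single layer.

```python
def mtp_key_mapping(num_layers: int) -> dict[str, str]:
    """Map xtuner MTP parameter keys to HF checkpoint keys.

    Sketch of the review suggestion: explicitly reject configurations
    the mapping does not yet support, instead of silently producing
    wrong keys for num_layers > 1.
    """
    if num_layers > 1:
        # The HF-side keys below have no per-layer index, so the mapping
        # is only well-defined for a single MTP layer (idx=0).
        raise NotImplementedError(
            "HF key mapping currently assumes a single MTP layer"
        )
    mapping = {}
    for idx in range(num_layers):
        mapping[f"mtp_block.layers.{idx}.enorm"] = "mtp.pre_fc_norm_embedding"
        mapping[f"mtp_block.layers.{idx}.hnorm"] = "mtp.pre_fc_norm_hidden"
        mapping[f"mtp_block.layers.{idx}.final_layernorm"] = "mtp.norm"
    return mapping
```

Alternatively, if multi-layer checkpoints are in scope, the HF key format would need a per-layer index before this guard could be dropped.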

attention mask, etc.

Returns:
list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: List of 3-tuples

🟡 Suggestion: The return type annotation says list[tuple[...]], but the append logic is confusing: at line 105, mtp_outputs.append(current_hidden_states) appends a tuple, while the name current_hidden_states suggests a single tensor. Consider renaming the variable (or introducing a type alias) so the element type of the list is obvious.
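One way to make the element type self-documenting is a type alias, as sketched below. The names `MTPOutput` and `collect_mtp_outputs` are illustrative, and the tuple contents (hidden states, logits, loss) are an assumption based on the docstring's "3-tuples" annotation, not the PR's actual shapes.

```python
import torch

# Explicit alias so list[MTPOutput] documents what each element holds.
MTPOutput = tuple[torch.Tensor, torch.Tensor, torch.Tensor]


def collect_mtp_outputs(depth: int) -> list[MTPOutput]:
    """Collect one (hidden_states, logits, loss) tuple per MTP depth."""
    per_depth_outputs: list[MTPOutput] = []
    for _ in range(depth):
        hidden = torch.zeros(2, 4)   # placeholder hidden states
        logits = torch.zeros(2, 8)   # placeholder logits
        loss = torch.zeros(())       # placeholder scalar loss
        # Appending a named tuple-of-tensors makes the intent clear,
        # unlike appending a variable named like a single tensor.
        per_depth_outputs.append((hidden, logits, loss))
    return per_depth_outputs
```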

)

# Compute MTP losses for each depth
mtp_losses = torch.tensor(0.0, device=DEVICE)

🟡 Suggestion: Creating a tensor with torch.tensor(0.0, device=DEVICE) inside a loop can be inefficient. Consider initializing outside the loop with torch.zeros(1, device=DEVICE).
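The suggested accumulation pattern could look like the sketch below. It is a standalone illustration: `accumulate_mtp_losses` is a hypothetical helper, and `DEVICE` is a placeholder for whatever device the PR's training loop uses.

```python
import torch

DEVICE = "cpu"  # placeholder; the PR presumably uses the training device


def accumulate_mtp_losses(per_depth_losses: list[torch.Tensor]) -> torch.Tensor:
    """Sum per-depth MTP losses into a single scalar tensor."""
    # Initialize the accumulator once, outside the loop, as the review
    # suggests (a 0-dim tensor keeps the result a true scalar).
    mtp_losses = torch.zeros((), device=DEVICE)
    for loss in per_depth_losses:
        mtp_losses = mtp_losses + loss
    return mtp_losses
```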

HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 25, 2026
HAOCHENYE added a commit to HAOCHENYE/xtuner that referenced this pull request Mar 26, 2026
