<!--- Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. NVIDIA CORPORATION and its licensors retain all intellectual property and proprietary rights in and to this software, related documentation and any modifications thereto. Any use, reproduction, disclosure or distribution of this software and related documentation without an express license agreement from NVIDIA CORPORATION is strictly prohibited. -->

Multi-Token Prediction (MTP)

Multi-Token Prediction (MTP) extends the prediction scope to multiple future tokens at each position. On the one hand, an MTP objective densifies the training signals and may improve data efficiency. On the other hand, MTP may enable the model to pre-plan its representations for better prediction of future tokens. In this implementation of MTP, we sequentially predict additional tokens and keep the complete causal chain at each prediction depth. The following figure illustrates our implementation of MTP in DeepSeek-V3.

The k-th MTP module consists of a shared embedding layer, a projection matrix, a Transformer block, and a shared output head. For the i-th input token at the k-th prediction depth, we first combine the representation of the i-th token at the (k - 1)-th depth with the embedding of the (i + k)-th token using the linear projection. The combined representation serves as the input of the Transformer block at the k-th depth, which produces the output representation at that depth.
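
In the notation of the DeepSeek-V3 report, this combination step can be sketched as follows. The RMSNorm placement and the symbol for the projection matrix are taken from that report and should be read as assumptions about notation rather than a description of this codebase.

```latex
% Sketch of the k-th MTP module (notation follows the DeepSeek-V3 report).
% h_i^{k-1}    : representation of token i at depth k-1 (depth 0 is the main model)
% Emb(t_{i+k}) : shared embedding of the (i+k)-th token
% M_k          : projection matrix of the k-th MTP module
\[
  {h'}_i^{\,k} = M_k \left[ \operatorname{RMSNorm}\!\left(h_i^{\,k-1}\right) ;\;
                            \operatorname{RMSNorm}\!\left(\operatorname{Emb}(t_{i+k})\right) \right]
\]
% The combined representations are fed to the k-th Transformer block to produce
% the output representations at depth k:
\[
  h_{1:T-k}^{\,k} = \operatorname{TRM}_k\!\left({h'}_{1:T-k}^{\,k}\right)
\]
```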

For more information, refer to the DeepSeek-V3 Technical Report.

We can train GPTModel-like models with Multi-Token Prediction (MTP) by setting mtp_num_layers to a positive integer.

  • mtp_num_layers: Number of Multi-Token Prediction (MTP) layers. MTP extends the prediction scope to multiple future tokens at each position. This MTP implementation sequentially predicts additional tokens, using D sequential modules to predict D additional tokens. Default is None.
  • mtp_loss_scaling_factor: Scaling factor of the Multi-Token Prediction (MTP) loss. We compute the average of the MTP losses across all depths and multiply it by the scaling factor to obtain the overall MTP loss, which serves as an additional training objective. Default is 0.1.
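
As a minimal sketch of how these options might be wired up (the import path, and the assumption that these fields live on the transformer config rather than somewhere else in your training script, are mine), enabling one MTP module with the default loss scaling could look like:

```python
# Hedged sketch: the field names follow the list above; the import path is an
# assumption about where TransformerConfig lives in Megatron-Core.
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=8,
    hidden_size=4096,
    num_attention_heads=32,
    # Enable MTP with one module: the model predicts 1 extra future token per position.
    mtp_num_layers=1,
    # The average of the per-depth MTP losses is multiplied by this factor and
    # added to the main next-token-prediction loss.
    mtp_loss_scaling_factor=0.1,
)
```

With D = mtp_num_layers prediction depths, the overall MTP objective is the scaling factor times the average of the D per-depth losses, added on top of the standard next-token loss.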

Pipeline Parallel Layout for MTP

MTP supports flexible placement of MTP layers across pipeline stages using a custom pipeline_model_parallel_layout. By default, all MTP layers are placed on the last pipeline stage, but you can customize their placement.

MTP Standalone Mode

When MTP layers are placed in a separate virtual pipeline (vpp) stage that is not on the last pipeline rank, the mtp_standalone flag is automatically set to True. This mode enables MTP to run independently in its own pipeline stage.

Layout Format

Use m to represent MTP layers in the pipeline layout string. For example:

  • "E|t*3|(t|)*5mL" - MTP in the last stage
  • "E|t*3|(t|)*4tm|L" - MTP in the second-to-last stage with a decoder layer
  • "E|t*3|(t|)*3tt|m|L" - MTP in a standalone stage (second-to-last) with no other layers

Constraints

  • All MTP layers must be placed in the same virtual pipeline stage.
  • MTP layers cannot be placed on the first pipeline rank.

Implementation Notes

  • For models with MTP layers, the final layernorm is placed in the stage that contains the last decoder layer, rather than in the post-processing stage. In deterministic mode, placing the final layernorm in a different pipeline stage may cause small numerical differences in the gradient norm reduction; bitwise alignment can be achieved by disabling gradient norm clipping.
  • MTP loss is computed in the post-processing stage.

Precautions

Do not use Context Parallel (CP), arbitrary AttnMaskType, or the learned absolute position embedding type with MTP; these use cases are not yet supported.