docs/api-guide/router_replay.md
This document provides a detailed description of the "Router Replay" feature implemented within the Megatron-LM Core for Mixture-of-Experts (MoE) models.
This feature is designed to enhance determinism and analyzability in MoE model training and inference. It enables the model to load routing decisions from a predefined file and enforce their use during the forward pass, thereby bypassing the real-time routing computation.
The design follows the principles of being non-intrusive and on-demand, with the core idea of activating the replay logic only when explicitly requested by the user.
Core Components:
- `RouterReplay` (located in `megatron/core/transformer/moe/router_replay.py`): A utility class for replaying MoE routing decisions. When enabled via the `moe_enable_routing_replay` flag, a separate `RouterReplay` instance is created for each MoE layer's router. Each instance is responsible for loading routing data and providing the deterministic routing decisions for its corresponding layer during the forward pass.
- `moe_enable_routing_replay` (located in `megatron/core/transformer/transformer_config.py`): A boolean global configuration flag that serves as the sole entry point for enabling this feature.
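Since the flag is the sole entry point, enabling the feature is a one-line configuration change. A minimal sketch, assuming standard Megatron-Core `TransformerConfig` field names (`num_moe_experts`, `moe_router_topk`) and placeholder sizes for the remaining required arguments:

```python
from megatron.core.transformer.transformer_config import TransformerConfig

# Sketch: moe_enable_routing_replay is the only replay-specific knob;
# the other fields are ordinary (assumed) Megatron-Core config arguments.
config = TransformerConfig(
    num_layers=1,
    hidden_size=64,
    num_attention_heads=4,
    num_moe_experts=8,
    moe_router_topk=2,
    moe_enable_routing_replay=True,  # turn on per-router RouterReplay instances
)
```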
Workflow:
The feature supports different modes, such as recording and replaying, controlled by a `RouterReplayAction`.
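The action names used throughout this document suggest an enum along these lines (a conceptual sketch, not the literal definition in `router_replay.py`):

```python
from enum import Enum, auto

class RouterReplayAction(Enum):
    RECORD = auto()           # capture the dynamically computed top-k indices
    REPLAY_FORWARD = auto()   # force pre-loaded indices during the forward pass
    REPLAY_BACKWARD = auto()  # re-use forward indices during recompute/backward
```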
1. The user sets `moe_enable_routing_replay` to `True` in the model configuration.
2. When `moe_enable_routing_replay` is true, each `TopKRouter` creates its own `RouterReplay` instance.
3. The user selects the desired action (`record`, `forward_replay`, `backward_replay`) on the `RouterReplay` instances.
4. During the forward pass, `topk_routing_with_score_function` checks the `router_replay_action`:
   - In `record` mode: the dynamically computed top-k expert indices are captured and stored.
   - In `forward_replay` mode: the function retrieves pre-loaded expert indices from `target_topk_idx`. These indices are used for the forward computation and are also appended to the `replay_backward_list` to prepare for the backward pass.
5. During recomputation, the `router_replay_action` is checked again:
   - In `backward_replay` mode: the function retrieves the expert indices for the corresponding micro-batch by popping them from the `replay_backward_list`. This mode is intended for training recomputation (e.g., activation checkpointing and pipeline recompute) so the same routing decisions are used during recompute/backward as in forward, ensuring determinism and correctness.

The implementation cleanly separates the replay logic from the router's core computation.
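To make the branching concrete, here is a simplified sketch of the decision logic. It is not the actual Megatron code (the real `topk_routing_with_score_function` also computes probabilities and score functions); names follow the description above:

```python
import torch

def route_with_replay(logits, topk, router_replay=None):
    """Simplified sketch of the replay branching inside the routing function."""
    action = getattr(router_replay, "router_replay_action", None)

    if action == RouterReplayAction.REPLAY_FORWARD:
        # Use the pre-loaded indices and remember them for the recompute pass.
        topk_idx = router_replay.target_topk_idx
        router_replay.replay_backward_list.append(topk_idx)
    elif action == RouterReplayAction.REPLAY_BACKWARD:
        # Recompute path: pop indices in FIFO order to mirror the forward sequence.
        topk_idx = router_replay.replay_backward_list.pop(0)
    else:
        # Default dynamic routing.
        _, topk_idx = torch.topk(logits, k=topk, dim=-1)
        if action == RouterReplayAction.RECORD:
            # Record mode: capture the dynamically computed indices.
            router_replay.record_indices(topk_idx)
    return topk_idx
```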
Implementation Details:
- `megatron/core/transformer/transformer_config.py`: Adds `moe_enable_routing_replay: bool = False`.
- `megatron/core/transformer/moe/router_replay.py`: Adds the `RouterReplay` class to manage the state for recording and replaying routing decisions for a single MoE layer, sketched just below.
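Conceptually, the per-layer state looks like the following skeleton. This is a sketch based on the attribute and method names described below, not the full implementation:

```python
import torch

class RouterReplay:
    """Sketch of the per-layer replay state; see the attribute list below."""

    def __init__(self):
        self.target_topk_idx = None      # indices to force during forward replay
        self.recorded_topk_idx = None    # indices captured in record mode
        self.replay_backward_list = []   # FIFO queue consumed during backward replay

    def set_target_indices(self, topk_idx: torch.Tensor):
        # Load the replay indices for the next forward pass.
        self.target_topk_idx = topk_idx

    def record_indices(self, topk_idx: torch.Tensor):
        # Save the indices computed by dynamic routing.
        self.recorded_topk_idx = topk_idx
```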
Its key attributes and methods:
- `target_topk_idx`: An attribute holding the expert indices for the current micro-batch during forward replay mode.
- `recorded_topk_idx`: An attribute for storing the computed expert indices when in record mode.
- `replay_backward_list`: A list that accumulates the top-k indices used during the forward passes of a mini-batch. This list is consumed in FIFO order during the backward pass to ensure correctness under pipeline parallelism.
- `set_target_indices()`: A method to load the replay indices into `target_topk_idx` for the forward pass.
- `record_indices()`: A method to save the computed indices.

- `megatron/core/transformer/moe/moe_utils.py`: `topk_routing_with_score_function` is modified to contain the core logic. It checks the `router_replay_action` on the `router_replay` instance and accordingly performs one of the following actions: computes and records indices, replays indices from `target_topk_idx` (for forward), replays indices from `replay_backward_list` (for backward), or falls through to the default dynamic routing. `set_target_indices()` prepares `replay_backward_list` so each micro-batch's indices are available for recomputation, and the action is set to `REPLAY_BACKWARD` so indices are consumed in FIFO order to mirror the forward sequence. One `RouterReplay` instance is created per MoE router layer when building the model.

Usage:
1. Set the global action to record: `RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD)`.
2. After the recorded run, collect the indices via `RouterReplay.get_recorded_data()` and persist them.
3. To replay, load the saved indices with `RouterReplay.set_replay_data(list_of_tensors)`.
4. Set the global action to replay: `RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)`.
5. Switch the action to `REPLAY_BACKWARD` during recomputation; indices are then popped from `replay_backward_list` in FIFO order.
6. When finished, call `RouterReplay.clear_global_indices()`, `RouterReplay.clear_global_router_replay_action()`, and `RouterReplay.clear_global_router_replay_instances()` to restore default behavior and prevent memory leaks.

Example: using `topk_routing_with_score_function` directly:

```python
import torch
from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
from megatron.core.transformer.moe.moe_utils import topk_routing_with_score_function
rr = RouterReplay()
# Record
RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD)
logits = torch.randn(8, 16)
probs_rec, routing_map_rec = topk_routing_with_score_function(
    logits=logits, topk=2, use_pre_softmax=False, score_function="softmax", router_replay=rr,
)
recorded = rr.get_recorded_indices()
torch.save(recorded, "/tmp/replay.pt")
# Forward replay
rr.clear_router_replay_action()
rr.set_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
target = torch.load("/tmp/replay.pt")
rr.set_target_indices(target)
probs_rep, routing_map_rep = topk_routing_with_score_function(
    logits=logits, topk=2, use_pre_softmax=False, score_function="softmax", router_replay=rr,
)
RouterReplay.clear_global_router_replay_action()
RouterReplay.clear_global_indices()
RouterReplay.clear_global_router_replay_instances()
```
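The backward path can be exercised the same way. A sketch, continuing the example above: in real training the recompute machinery re-invokes the routing function under `REPLAY_BACKWARD`; here we call it once more by hand before the cleanup calls:

```python
# Sketch: during recomputation the routing function runs again and pops the
# indices stored by the forward replay in FIFO order.
rr.set_router_replay_action(RouterReplayAction.REPLAY_BACKWARD)
probs_bwd, routing_map_bwd = topk_routing_with_score_function(
    logits=logits, topk=2, use_pre_softmax=False, score_function="softmax", router_replay=rr,
)
# The routing map matches the forward replay, making the recompute deterministic.
assert torch.equal(routing_map_rep, routing_map_bwd)
```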
Here is a minimal end-to-end example showing how to use `RouterReplay` with a `TopKRouter` for recording and replaying:
```python
import torch
import torch.distributed as dist
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.moe.router import TopKRouter
from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
# Initialize distributed training
if not dist.is_initialized():
    dist.init_process_group(backend="nccl")
# Create a transformer config with RouterReplay enabled.
# Note: this assumes standard Megatron-Core field names (num_moe_experts,
# moe_router_topk); num_layers/hidden_size/num_attention_heads are minimal
# values for illustration.
config = TransformerConfig(
    num_layers=1,
    hidden_size=8,
    num_attention_heads=1,
    num_moe_experts=8,
    expert_model_parallel_size=1,
    moe_router_topk=2,
    moe_enable_routing_replay=True,
)
# Create a TopKRouter instance
router = TopKRouter(config)
# Generate sample input hidden states (batch_size, sequence_length, hidden_size)
hidden_states = torch.randn(16, 32, 8).to(torch.cuda.current_device())
# -----------------
# 1. Recording Mode
# -----------------
print("=== Recording Mode ===")
# Set global router replay action to RECORD
RouterReplay.set_global_router_replay_action(RouterReplayAction.RECORD)
# Perform routing
routing_output = router.forward(hidden_states)
print(f"Recorded top-k indices shape: {routing_output.top_k_idx.shape}")
# -----------------
# 2. Forward Replay Mode
# -----------------
print("\n=== Forward Replay Mode ===")
# Save recorded indices to a file
torch.save(routing_output.top_k_idx, "/tmp/replay.pt")
# Load indices from file and set as target for replay
replay_indices = torch.load("/tmp/replay.pt")
for router_instance in RouterReplay.global_router_replay_instances:
    router_instance.target_topk_idx = replay_indices
# Set global router replay action to REPLAY_FORWARD
RouterReplay.set_global_router_replay_action(RouterReplayAction.REPLAY_FORWARD)
# Perform routing again - this will use the replayed indices
replay_routing_output = router.forward(hidden_states)
print(f"Replayed top-k indices shape: {replay_routing_output.top_k_idx.shape}")
print(f"Are indices the same? {torch.equal(routing_output.top_k_idx, replay_routing_output.top_k_idx)}")
# Clean up
RouterReplay.clear_global_router_replay_action()
RouterReplay.clear_global_indices()
RouterReplay.clear_global_router_replay_instances()
if dist.is_initialized():
    dist.destroy_process_group()
```