docs/source/en/tensor_parallelism.md
Tensor parallelism (TP) splits weight matrices column-wise or row-wise across GPUs. Each GPU holds a shard, computes a partial result, and synchronizes with an all-reduce to produce the full output.
TP relies on frequent cross-GPU communication. It works best on hardware with fast intra-node links such as NVLink.

    ┌─────────────────────────────┐
    │       X (replicated)        │
    └────┬──────────┬─────────┬───┘
         │          │         │
    ┌────▼───┐ ┌────▼───┐ ┌───▼────┐
    │ ▓▓▓ W₀ │ │ ░░░ W₁ │ │ ███ W₂ │
    │ X@W₀   │ │ X@W₁   │ │ X@W₂   │
    └────┬───┘ └────┬───┘ └───┬────┘
         └──────────┼─────────┘
                 Y₀+Y₁+Y₂
    ┌────────────────────────────┐
    │          Y (full)          │
    └────────────────────────────┘

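
The arithmetic behind the two sharding directions can be checked without any GPUs. The sketch below is a pure-Python illustration, not how Transformers implements TP: it uses two simulated ranks instead of the three in the diagram, and the helper names `matmul` and `add` are made up for this example. Column-wise shards each produce a slice of the output columns, which are concatenated (an all-gather in a real run). Row-wise shards each produce a full-size partial result, which are summed (the all-reduce).

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    """Add two matrices of the same shape element by element."""
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


X = [[1, 2, 3, 4],
     [5, 6, 7, 8]]          # activations, shape 2 x 4
W = [[1, 0, 2, 1],
     [0, 1, 1, 2],
     [3, 1, 0, 0],
     [1, 2, 1, 3]]          # weights, shape 4 x 4

world_size = 2              # number of simulated GPUs (an assumption for this sketch)
full = matmul(X, W)         # reference result computed without sharding

# Column-wise: every rank sees all of X and holds a slice of W's columns.
width = len(W[0]) // world_size
col_shards = [[row[r * width:(r + 1) * width] for row in W] for r in range(world_size)]
col_outputs = [matmul(X, shard) for shard in col_shards]
# Concatenate the per-rank column slices back into full rows.
col_result = [sum((out[i] for out in col_outputs), []) for i in range(len(X))]

# Row-wise: every rank holds a slice of W's rows and the matching columns of X.
height = len(W) // world_size
row_shards = [W[r * height:(r + 1) * height] for r in range(world_size)]
x_shards = [[row[r * height:(r + 1) * height] for row in X] for r in range(world_size)]
partials = [matmul(x, w) for x, w in zip(x_shards, row_shards)]
# Sum the partial results, which is what an all-reduce does across ranks.
row_result = partials[0]
for partial in partials[1:]:
    row_result = add(row_result, partial)

assert col_result == full
assert row_result == full
```

Both layouts reproduce the unsharded product. They differ in what each rank stores and in which collective stitches the result back together.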
Transformers supports TP for architectures whose config defines base_model_tp_plan. Check that field first to see whether a model supports native TP.
from transformers import AutoConfig
config = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")
print(config.base_model_tp_plan is not None)
print(config.base_model_tp_plan)
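
To make the printed plan easier to read, the sketch below assumes the plan is a dict that maps module-name patterns to a sharding style such as `"colwise"` or `"rowwise"`. The entries in `example_plan` are illustrative rather than copied from a real config, `lookup_style` is a hypothetical helper, and the `fnmatch` wildcard matching stands in for whatever matching Transformers does internally.

```python
from fnmatch import fnmatch

# Illustrative plan in the style of base_model_tp_plan (entries are assumptions).
example_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.up_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}


def lookup_style(module_name, plan):
    """Return the sharding style for a module name, or None if the plan skips it."""
    for pattern, style in plan.items():
        if fnmatch(module_name, pattern):
            return style
    return None


print(lookup_style("layers.0.self_attn.q_proj", example_plan))   # colwise
print(lookup_style("layers.11.mlp.down_proj", example_plan))     # rowwise
print(lookup_style("layers.3.input_layernorm", example_plan))    # None
```

Modules that no pattern covers, like the layer norm in the last call, would stay replicated on every GPU under this reading of the plan.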
If a model supports TP, set tp_plan="auto" in [~PreTrainedModel.from_pretrained]. Transformers initializes the device mesh and shards the supported layers for you.
[!WARNING]
Don't use device_map with tp_plan. The two conflict at the weight-loading level. device_map places whole modules on specific GPUs, while tp_plan shards those same parameters across all GPUs.
import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    dtype=torch.bfloat16,
    tp_plan="auto",
)
[Trainer] detects tp_plan, reads tp_size from the model, and creates a [~accelerate.parallelism_config.ParallelismConfig] automatically.
Launch training on one node with 4 GPUs.
torchrun --nproc-per-node 4 train_tp.py
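
`torchrun` starts one process per GPU and tells each process where it sits through environment variables such as `RANK`, `LOCAL_RANK`, and `WORLD_SIZE`. A minimal sketch of reading them inside a script like `train_tp.py`, for example to log from a single process only. The helper name `read_dist_env` is hypothetical, and the fallback values are an assumption so the script still runs without `torchrun`.

```python
import os


def read_dist_env(env=None):
    """Read the process layout that torchrun exports, with single-process fallbacks."""
    env = os.environ if env is None else env
    return {
        "rank": int(env.get("RANK", 0)),              # global index of this process
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # index within this node
        "world_size": int(env.get("WORLD_SIZE", 1)),  # total number of processes
    }


layout = read_dist_env()
if layout["rank"] == 0:
    # Only the first process prints, so the message appears once instead of once per GPU.
    print(f"running with {layout['world_size']} process(es)")
```

With the launch command above, `world_size` would be 4 and each process would see a different `rank` from 0 to 3.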
Pass [~accelerate.parallelism_config.ParallelismConfig] explicitly when combining TP with other parallelism techniques like FSDP.
import torch
from accelerate import ParallelismConfig
from transformers import AutoModelForCausalLM, TrainingArguments
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    dtype=torch.bfloat16,
    tp_plan="auto",
)
parallelism_config = ParallelismConfig(tp_size=4)
args = TrainingArguments(
    ...,
    parallelism_config=parallelism_config,
)
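
When TP is combined with another parallelism dimension, the sizes multiply, and the product generally has to match the number of processes you launch. The sketch below is plain Python and not part of Accelerate. `required_world_size` and `check_layout` are hypothetical helpers, and `dp_shard_size` is assumed here as the name of the FSDP dimension, so check the `ParallelismConfig` reference for the exact fields.

```python
from math import prod


def required_world_size(**sizes):
    """Return how many processes a combination of parallelism sizes needs."""
    for name, size in sizes.items():
        if size < 1:
            raise ValueError(f"{name} must be at least 1, got {size}")
    return prod(sizes.values())


def check_layout(world_size, **sizes):
    """Raise if the parallelism sizes do not multiply out to the launched world size."""
    needed = required_world_size(**sizes)
    if needed != world_size:
        raise ValueError(f"sizes {sizes} need {needed} processes, but {world_size} were launched")
    return needed


# TP across 4 GPUs combined with 2-way sharded data parallelism needs 8 processes.
print(required_world_size(tp_size=4, dp_shard_size=2))  # 8

# TP alone on the 4-GPU node from the launch command above.
print(check_layout(4, tp_size=4))  # 4
```

A check like this at the top of a training script catches a mismatch between the launch command and the configured sizes before any weights are loaded.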