code/chapter11/accelerate_configs/README.md
本目录包含用于分布式训练的Accelerate配置文件。
数据并行(DDP) - 最简单的多GPU训练方案
使用方法:
accelerate launch --config_file accelerate_configs/multi_gpu_ddp.yaml train_script.py
DeepSpeed ZeRO-2 - 优化器状态分片
使用方法:
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml train_script.py
DeepSpeed ZeRO-3 - 完整模型分片
使用方法:
accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml train_script.py
pip install accelerate deepspeed
方式1: 使用配置文件(推荐)
accelerate launch --config_file accelerate_configs/multi_gpu_ddp.yaml your_script.py
方式2: 交互式配置
accelerate config
方式3: 命令行参数
accelerate launch --num_processes 4 --mixed_precision fp16 your_script.py
# DDP训练(4卡)
accelerate launch --config_file accelerate_configs/multi_gpu_ddp.yaml 07_distributed_training.py
# DeepSpeed ZeRO-2训练(4卡)
accelerate launch --config_file accelerate_configs/deepspeed_zero2.yaml 07_distributed_training.py
# DeepSpeed ZeRO-3训练(4卡)
accelerate launch --config_file accelerate_configs/deepspeed_zero3.yaml 07_distributed_training.py
compute_environment: 计算环境(LOCAL_MACHINE/AMAZON_SAGEMAKER等)distributed_type: 分布式类型(MULTI_GPU/DEEPSPEED/FSDP等)num_processes: 总进程数(通常等于GPU数量)machine_rank: 机器编号(主节点为0)num_machines: 机器数量gpu_ids: 使用的GPU ID(all表示使用所有GPU)mixed_precision: 混合精度训练(no/fp16/bf16)zero_stage: ZeRO优化级别(1/2/3)
offload_optimizer_device: 优化器状态卸载设备(none/cpu/nvme)
offload_param_device: 模型参数卸载设备(none/cpu/nvme)
gradient_accumulation_steps: 梯度累积步数
gradient_clipping: 梯度裁剪阈值
zero3_init_flag: ZeRO-3初始化标志
分布式训练时,总batch size = per_device_batch_size × num_gpus × gradient_accumulation_steps
示例:
# 单GPU: batch_size=4, gradient_accumulation=4, 总batch=16
# 4GPU DDP: batch_size=4, gradient_accumulation=1, 总batch=16
使用线性缩放规则:
lr_new = lr_base × sqrt(total_batch_size_new / total_batch_size_base)
当显存不足时,可以增大gradient_accumulation_steps:
deepspeed_config:
gradient_accumulation_steps: 8 # 增大累积步数
accelerate env
可能原因:
解决方法:
可能原因:
解决方法:
# 启用调试日志
export ACCELERATE_LOG_LEVEL=INFO
export NCCL_DEBUG=INFO
# 增加超时时间
export NCCL_TIMEOUT=1800
num_machines和main_process_ip