optional-skills/mlops/torchtitan/references/float8.md
Float8 training provides substantial speedups for models whose GEMMs are large enough that the FP8 tensor core speedup outweighs the overhead of dynamic quantization.
USE_CPP=0 pip install git+https://github.com/pytorch/ao.git
Standard Float8 with tensorwise dynamic scaling:
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
--model.converters="quantize.linear.float8" \
--quantize.linear.float8.enable_fsdp_float8_all_gather \
--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp \
--compile.enable
| Argument | Description |
|---|---|
| `--model.converters="quantize.linear.float8"` | Swap nn.Linear with Float8Linear |
| `--quantize.linear.float8.enable_fsdp_float8_all_gather` | Communicate in float8 to save bandwidth |
| `--quantize.linear.float8.precompute_float8_dynamic_scale_for_fsdp` | Single all-reduce for all AMAX/scales |
| `--compile.enable` | Required - fuses float8 scaling/casting kernels |
Rowwise scaling offers higher accuracy than tensorwise scaling:
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
--model.converters="quantize.linear.float8" \
--quantize.linear.float8.recipe_name rowwise \
--compile.enable
Not all layers benefit from Float8. Filter small layers:
--quantize.linear.float8.filter_fqns="attention.wk,attention.wv,output"
Automatically skip layers too small to benefit:
--quantize.linear.float8.filter_fqns="auto_filter_small_kn"
The thresholds are based on H100 microbenchmarks and mark the sizes at which the Float8 speedup exceeds the quantization overhead.
[model]
converters = ["quantize.linear.float8"]
[quantize.linear.float8]
enable_fsdp_float8_all_gather = true
precompute_float8_dynamic_scale_for_fsdp = true
filter_fqns = ["output", "auto_filter_small_kn"]
[compile]
enable = true
components = ["model", "loss"]
Cast input and weight to float8 inside forward before calling torch._scaled_mm:
# Float8 matmul requires scales
torch._scaled_mm(input_fp8, weight_fp8, scale_a=scale_input, scale_b=scale_weight)
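To illustrate why the matmul needs scales, here is a minimal pure-Python sketch. It emulates only the range limit of the e4m3 format (largest finite value 448) by clamping and ignores mantissa rounding. The helper names `fake_cast_fp8` and `scaled_dot` are hypothetical. The convention used here (multiply by the scale before the cast, divide afterwards) is an assumption for illustration and is not necessarily how `torch._scaled_mm` interprets `scale_a` and `scale_b`.

```python
E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3


def fake_cast_fp8(x):
    """Emulate only the e4m3 range limit by clamping (mantissa rounding is ignored)."""
    return max(-E4M3_MAX, min(E4M3_MAX, x))


def scaled_dot(a, b, eps=1e-12):
    """Dot product with per-vector dynamic scaling around the emulated float8 cast."""
    # Dynamic scaling: map each vector's max(abs) onto the float8 range.
    scale_a = E4M3_MAX / max(max(abs(v) for v in a), eps)
    scale_b = E4M3_MAX / max(max(abs(v) for v in b), eps)
    a_fp8 = [fake_cast_fp8(v * scale_a) for v in a]
    b_fp8 = [fake_cast_fp8(v * scale_b) for v in b]
    acc = sum(x * y for x, y in zip(a_fp8, b_fp8))
    # Undo both scales so the result is back in the original units.
    return acc / (scale_a * scale_b)


a = [1000.0, -2000.0]   # exceeds the float8 range without scaling
b = [0.001, 0.002]

# Casting without scales saturates `a` at +/-448 and corrupts the result.
unscaled = sum(fake_cast_fp8(x) * fake_cast_fp8(y) for x, y in zip(a, b))
scaled = scaled_dot(a, b)
```

The exact dot product is -3. The unscaled cast yields about -0.448 because both entries of `a` saturate, while the scaled version recovers approximately -3.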
max(abs) across ranks for scale computation

Net benefit: Float8 all-gather + amax communication can beat bf16/fp32 all-gather, depending on world size and message size.
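The cross-rank max(abs) step can be sketched without a real process group. In the sketch below, a plain Python `max` over per-rank values stands in for an all-reduce with a MAX reduction. The shard values and the helper names `local_amax` and `global_scale` are hypothetical, and `scale = 448 / amax` is one common convention assumed here for illustration.

```python
E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3


def local_amax(shard):
    """max(abs) over the values held by a single rank."""
    return max(abs(v) for v in shard)


def global_scale(shards, eps=1e-12):
    """Scale derived from the max(abs) over all ranks.

    The outer max() stands in for an all-reduce(MAX), so every rank
    ends up with the same amax and therefore the same scale.
    """
    amax = max(local_amax(s) for s in shards)
    return E4M3_MAX / max(amax, eps)


# One list per rank: a weight tensor sharded across four hypothetical ranks.
shards = [[0.25, -3.5], [7.0, 1.0], [-14.0, 2.0], [0.5, 0.125]]

shared_scale = global_scale(shards)             # 448 / 14
rank0_only = E4M3_MAX / local_amax(shards[0])   # 448 / 3.5, inconsistent across ranks
```

If each rank used only its local amax, the shards would be quantized with different scales and could not be treated as one tensorwise-scaled tensor after the all-gather, which is why the amax is communicated.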
max(abs) for sharded weights

| Strategy | Status | Description |
|---|---|---|
| Tensorwise dynamic | Stable | Single scale per tensor |
| Rowwise dynamic | Alpha | Scale per row, higher accuracy |
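The difference between the two strategies can be sketched in pure Python. This is an illustration of the idea only, not the torchao implementation. The function names are hypothetical, and `scale = 448 / amax` (448 being the largest finite e4m3 value) is an assumed convention.

```python
E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3


def tensorwise_scale(matrix, eps=1e-12):
    """Single scale for the whole tensor, from the global max(abs)."""
    amax = max(abs(v) for row in matrix for v in row)
    return E4M3_MAX / max(amax, eps)


def rowwise_scales(matrix, eps=1e-12):
    """One scale per row, from each row's own max(abs)."""
    return [E4M3_MAX / max(max(abs(v) for v in row), eps) for row in matrix]


# The second row contains an outlier that dominates the global max(abs).
x = [
    [0.5, -1.0],
    [100.0, 224.0],
]

t_scale = tensorwise_scale(x)   # driven by the outlier 224.0
r_scales = rowwise_scales(x)    # the first row keeps its own, much larger scale
```

With a single tensorwise scale, the small-magnitude first row is mapped into a narrow slice of the float8 range. With rowwise scales it uses the full range, which is the intuition behind the higher accuracy at the cost of tracking more scales.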
From benchmarks on H100:
| Configuration | TPS/GPU | vs Baseline |
|---|---|---|
| FSDP only | 5,762 | - |
| FSDP + compile | 6,667 | +16% |
| FSDP + compile + Float8 | 8,532 | +48% |
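The "vs Baseline" column can be reproduced from the TPS/GPU figures. The short sketch below also derives the gain of Float8 relative to the compiled run, a number the table does not list. The helper name `pct_gain` is hypothetical.

```python
def pct_gain(tps, baseline):
    """Percentage throughput gain of `tps` over `baseline`, rounded to an integer."""
    return round((tps / baseline - 1.0) * 100)


# TPS/GPU values taken from the table above.
fsdp = 5762
fsdp_compile = 6667
fsdp_compile_float8 = 8532

compile_gain = pct_gain(fsdp_compile, fsdp)                        # 16
float8_gain = pct_gain(fsdp_compile_float8, fsdp)                  # 48
float8_over_compile = pct_gain(fsdp_compile_float8, fsdp_compile)  # 28
```

On these numbers, Float8 adds roughly 28% on top of the compiled FSDP run. The rest of the +48% comes from `torch.compile` itself.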
Check torchao microbenchmarks for forward+backward pass speedups on "layer norm => linear => sigmoid" for different M,N,K sizes.
Rule of thumb: GEMMs with K,N > 4096 typically benefit from Float8.
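As a rough illustration of applying this rule of thumb by hand, the sketch below builds a comma-separated list of layer names to skip. It is not the logic behind `auto_filter_small_kn`, whose thresholds come from microbenchmarks. The helper name `layers_to_filter` and the layer shapes are hypothetical.

```python
def layers_to_filter(layer_shapes, threshold=4096):
    """Return a comma-separated list of layers whose K or N is not above the threshold.

    `layer_shapes` maps a layer name to its (K, N) GEMM dimensions.
    """
    skipped = [
        name
        for name, (k, n) in layer_shapes.items()
        if not (k > threshold and n > threshold)
    ]
    return ",".join(skipped)


# Hypothetical (K, N) shapes for a model with grouped-query attention.
shapes = {
    "attention.wq": (8192, 8192),
    "attention.wk": (8192, 1024),
    "attention.wv": (8192, 1024),
    "feed_forward.w1": (8192, 28672),
}

filter_value = layers_to_filter(shapes)
```

The resulting string has the same comma-separated shape as the `filter_fqns` value shown earlier. Whether a given layer actually benefits should still be confirmed with the microbenchmarks.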
For NVIDIA Blackwell GPUs, TorchTitan supports MXFP8 (Microscaling FP8) for both dense and MoE models. See docs/mxfp8.md for details.