examples/research_projects/pytorch_xla/inference/flux/README.md
The flux_inference script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation using custom flash block sizes for better performance on Trillium TPU versions. No other TPU types have been tested.
To create a TPU on Google Cloud, follow this guide
SSH into the VM and install Pytorch, Pytorch/XLA
pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
Verify that PyTorch and PyTorch/XLA were installed correctly:
python3 -c "import torch; import torch_xla;"
Clone the diffusers repo and install dependencies
git clone https://github.com/huggingface/diffusers.git
cd diffusers
pip install transformers accelerate sentencepiece structlog
pip install .
cd examples/research_projects/pytorch_xla/inference/flux/
Gated Model
As the model is gated, before using it with diffusers you first need to go to the FLUX.1 [dev] Hugging Face page, fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
hf auth login
Then run:
python flux_inference.py
The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, the compilation time is longer, while the cache stores the compiled programs. On subsequent runs, compilation is much faster and the subsequent passes being the fastest.
On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel).
Note:
flux_inference.pyusesxmp.spawn(one process per chip) and requires the full model to fit on a single chip. If you run into OOM errors (e.g., on v5e with 16GB HBM per chip), use the SPMD version instead — see below.
On TPU configurations where a single chip cannot hold the full FLUX transformer (~16GB in bf16), use flux_inference_spmd.py. This script uses PyTorch/XLA SPMD to shard the transformer across multiple chips using a (data, model) mesh — 4-way model parallel so each chip holds ~4GB of weights, with the remaining chips for data parallelism.
python flux_inference_spmd.py --schnell
Key differences from flux_inference.py:
xmp.spawn — the XLA compiler handles all collective communication transparently."model" mesh axis using xs.mark_sharding.On a v5litepod-8 (v5e, 8 chips, 16GB HBM each) with FLUX.1-schnell, expect ~1.76 sec/image at steady state (after compilation):
2026-04-15 02:24:30 [info ] SPMD mesh: (2, 4), axes: ('data', 'model'), devices: 8
2026-04-15 02:24:30 [info ] encoding prompt on CPU...
2026-04-15 02:26:20 [info ] loading VAE on CPU...
2026-04-15 02:26:20 [info ] loading flux transformer from black-forest-labs/FLUX.1-schnell
2026-04-15 02:27:22 [info ] starting compilation run...
2026-04-15 02:52:55 [info ] compilation took 1533.4575625509997 sec.
2026-04-15 02:52:56 [info ] starting inference run...
2026-04-15 02:56:11 [info ] inference time: 195.74092420299985
2026-04-15 02:56:13 [info ] inference time: 1.7625778899996476
2026-04-15 02:56:13 [info ] avg. inference over 2 iterations took 98.75175104649975 sec.
The first inference iteration includes VAE compilation (~195s). The second iteration shows the true steady-state speed (~1.76s).
flux_inference.py)WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 7.06it/s]
Loading pipeline components...: 60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 3/5 [00:00<00:00, 6.80it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 6.28it/s]
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-03-14 21:17:53 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:54 [info ] loading flux from black-forest-labs/FLUX.1-dev
2025-03-14 21:17:54 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.66it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 4.48it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.32it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.69it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.74it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.10it/s]
2025-03-14 21:17:56 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 0%| | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:56 [info ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 3.55it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00, 1.46it/s]
2025-03-14 21:18:34 [info ] starting compilation run...
2025-03-14 21:18:37 [info ] starting compilation run...
2025-03-14 21:18:38 [info ] starting compilation run...
2025-03-14 21:18:39 [info ] starting compilation run...
2025-03-14 21:18:41 [info ] starting compilation run...
2025-03-14 21:18:41 [info ] starting compilation run...
2025-03-14 21:18:42 [info ] starting compilation run...
2025-03-14 21:18:43 [info ] starting compilation run...
82%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 23/28 [13:35<03:04, 36.80s/it]2025-03-14 21:33:42.057559: W torch_xla/csrc/runtime/pjrt_computation_client.cc:667] Failed to deserialize executable: INTERNAL: TfrtTpuExecutable proto deserialization failed while parsing core program!
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.28s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.26s/it]
2025-03-14 21:36:38 [info ] compilation took 1079.3314765350078 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
2025-03-14 21:36:38 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
2025-03-14 21:36:38 [info ] compilation took 1081.89390801001 sec.
2025-03-14 21:36:39 [info ] starting inference run...
2025-03-14 21:36:39 [info ] compilation took 1077.1543154849933 sec.
2025-03-14 21:36:39 [info ] compilation took 1075.7239800530078 sec.
2025-03-14 21:36:39 [info ] starting inference run...
2025-03-14 21:36:40 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:22<00:00, 35.10s/it]
2025-03-14 21:36:50 [info ] compilation took 1088.1632604240003 sec.
2025-03-14 21:36:50 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:28<00:00, 35.32s/it]
2025-03-14 21:36:55 [info ] compilation took 1096.8027802760043 sec.
2025-03-14 21:36:56 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:59<00:00, 36.40s/it]
2025-03-14 21:37:08 [info ] compilation took 1113.8591305939917 sec.
2025-03-14 21:37:08 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:55<00:00, 36.26s/it]
2025-03-14 21:37:22 [info ] compilation took 1120.5590810020076 sec.
2025-03-14 21:37:22 [info ] starting inference run...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.00it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00, 2.98it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.08it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00, 2.82it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:08<00:00, 3.34it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.22it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.09it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:11<00:00, 2.41it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00, 4.50it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.10it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.27it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 4.80it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.39it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00, 5.39it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.67it/s]
29%|█████████████████████████████████████████████████████████████████████████████▍ | 8/28 [00:01<00:03, 6.08it/s]/home/jfacevedo_google_com/diffusers/src/diffusers/image_processor.py:147: RuntimeWarning: invalid value encountered in cast
images = (images * 255).round().astype("uint8")
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.82it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.93it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.98it/s]
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.03it/s]2025-03-14 21:38:32 [info ] inference time: 5.962021178987925
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.89it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.09it/s]
2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.2685392687970305 sec.
2025-03-14 21:38:32 [info ] avg. inference over 5 iterations took 7.402720856998348 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.01it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.89it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.96it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.06it/s]
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.01it/s]2025-03-14 21:38:38 [info ] inference time: 5.950578948002658
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.87it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.09it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.00it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.86it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.99it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.05it/s]
2025-03-14 21:38:43 [info ] avg. inference over 5 iterations took 6.763298449796276 sec.
71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊ | 20/28 [00:03<00:01, 6.04it/s]2025-03-14 21:38:44 [info ] inference time: 5.949129879008979
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.92it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.10it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
39%|██████████████████████████████████████████████████████████████████████████████████████████████████████████ | 11/28 [00:01<00:02, 5.98it/s]2025-03-14 21:38:46 [info ] avg. inference over 5 iterations took 7.221068455604836 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.96it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.08it/s]
93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 26/28 [00:04<00:00, 5.92it/s]2025-03-14 21:38:50 [info ] inference time: 5.954778069004533
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.90it/s]
11%|█████████████████████████████ | 3/28 [00:00<00:04, 6.03it/s]2025-03-14 21:38:50 [info ] avg. inference over 5 iterations took 6.05970350120042 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 6.02it/s]
32%|███████████████████████████████████████████████████████████████████████████████████████ | 9/28 [00:01<00:03, 5.99it/s]2025-03-14 21:38:51 [info ] avg. inference over 5 iterations took 6.018543455796316 sec.
54%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋ | 15/28 [00:02<00:02, 6.00it/s]2025-03-14 21:38:52 [info ] avg. inference over 5 iterations took 5.9609976705978625 sec.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████���█████████████████████████████████████████████████████████| 28/28 [00:04<00:00, 5.97it/s]
2025-03-14 21:38:56 [info ] inference time: 5.944058528999449
2025-03-14 21:38:56 [info ] avg. inference over 5 iterations took 5.952113320800708 sec.
2025-03-14 21:38:56 [info ] saved metric information as /tmp/metrics_report.txt