Back to Diffusers

Generating images using Flux and PyTorch/XLA

examples/research_projects/pytorch_xla/inference/flux/README.md

0.38.063.0 KB
Original Source

Generating images using Flux and PyTorch/XLA

The flux_inference script shows how to do image generation using Flux on TPU devices using PyTorch/XLA. It uses the pallas kernel for flash attention for faster generation using custom flash block sizes for better performance on Trillium TPU versions. No other TPU types have been tested.

Create TPU

To create a TPU on Google Cloud, follow this guide

Setup TPU environment

SSH into the VM and install Pytorch, Pytorch/XLA

bash
pip install torch~=2.5.0 torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

Verify that PyTorch and PyTorch/XLA were installed correctly:

bash
python3 -c "import torch; import torch_xla;"

Clone the diffusers repo and install dependencies

bash
git clone https://github.com/huggingface/diffusers.git
cd diffusers
pip install transformers accelerate sentencepiece structlog
pip install .
cd examples/research_projects/pytorch_xla/inference/flux/

Run the inference job

Authenticate

Gated Model

As the model is gated, before using it with diffusers you first need to go to the FLUX.1 [dev] Hugging Face page, fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:

bash
hf auth login

Then run:

bash
python flux_inference.py

The script loads the text encoders onto the CPU and the Flux transformer and VAE models onto the TPU. The first time the script runs, the compilation time is longer, while the cache stores the compiled programs. On subsequent runs, compilation is much faster and the subsequent passes being the fastest.

On a Trillium v6e-4, you should expect ~6 sec / 4 images or 1.5 sec / image (as devices run generation in parallel).

Note: flux_inference.py uses xmp.spawn (one process per chip) and requires the full model to fit on a single chip. If you run into OOM errors (e.g., on v5e with 16GB HBM per chip), use the SPMD version instead — see below.

SPMD version (for v5e-8 and similar)

On TPU configurations where a single chip cannot hold the full FLUX transformer (~16GB in bf16), use flux_inference_spmd.py. This script uses PyTorch/XLA SPMD to shard the transformer across multiple chips using a (data, model) mesh — 4-way model parallel so each chip holds ~4GB of weights, with the remaining chips for data parallelism.

bash
python flux_inference_spmd.py --schnell

Key differences from flux_inference.py:

  • Single-process SPMD instead of multi-process xmp.spawn — the XLA compiler handles all collective communication transparently.
  • Transformer weights are sharded across the "model" mesh axis using xs.mark_sharding.
  • VAE lives on CPU, moved to XLA only for decode (then moved back), since the transformer stays on device throughout.
  • Text encoding runs on CPU before loading the transformer.

On a v5litepod-8 (v5e, 8 chips, 16GB HBM each) with FLUX.1-schnell, expect ~1.76 sec/image at steady state (after compilation):

2026-04-15 02:24:30 [info     ] SPMD mesh: (2, 4), axes: ('data', 'model'), devices: 8
2026-04-15 02:24:30 [info     ] encoding prompt on CPU...
2026-04-15 02:26:20 [info     ] loading VAE on CPU...
2026-04-15 02:26:20 [info     ] loading flux transformer from black-forest-labs/FLUX.1-schnell
2026-04-15 02:27:22 [info     ] starting compilation run...
2026-04-15 02:52:55 [info     ] compilation took 1533.4575625509997 sec.
2026-04-15 02:52:56 [info     ] starting inference run...
2026-04-15 02:56:11 [info     ] inference time: 195.74092420299985
2026-04-15 02:56:13 [info     ] inference time: 1.7625778899996476
2026-04-15 02:56:13 [info     ] avg. inference over 2 iterations took 98.75175104649975 sec.

The first inference iteration includes VAE compilation (~195s). The second iteration shows the true steady-state speed (~1.76s).

v6e-4 results (original flux_inference.py)

bash
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.06it/s]
Loading pipeline components...:  60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                | 3/5 [00:00<00:00,  6.80it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  6.28it/s]
2025-03-14 21:17:53 [info     ] loading flux from black-forest-labs/FLUX.1-dev
2025-03-14 21:17:53 [info     ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...:   0%|                                                                                                                                                                                                                                                        | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:53 [info     ] loading flux from black-forest-labs/FLUX.1-dev
2025-03-14 21:17:53 [info     ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...:   0%|                                                                                                                                                                                                                                                        | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:54 [info     ] loading flux from black-forest-labs/FLUX.1-dev
2025-03-14 21:17:54 [info     ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.66it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  4.48it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.32it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.69it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.74it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.10it/s]
2025-03-14 21:17:56 [info     ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...:   0%|                                                                                                                                                                                                                                                        | 0/3 [00:00<?, ?it/s]2025-03-14 21:17:56 [info     ] loading flux from black-forest-labs/FLUX.1-dev
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00,  3.55it/s]
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:02<00:00,  1.46it/s]
2025-03-14 21:18:34 [info     ] starting compilation run...   
2025-03-14 21:18:37 [info     ] starting compilation run...   
2025-03-14 21:18:38 [info     ] starting compilation run...   
2025-03-14 21:18:39 [info     ] starting compilation run...   
2025-03-14 21:18:41 [info     ] starting compilation run...   
2025-03-14 21:18:41 [info     ] starting compilation run...   
2025-03-14 21:18:42 [info     ] starting compilation run...   
2025-03-14 21:18:43 [info     ] starting compilation run...   
 82%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                | 23/28 [13:35<03:04, 36.80s/it]2025-03-14 21:33:42.057559: W torch_xla/csrc/runtime/pjrt_computation_client.cc:667] Failed to deserialize executable: INTERNAL: TfrtTpuExecutable proto deserialization failed while parsing core program!
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.28s/it]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:27<00:00, 35.26s/it]
2025-03-14 21:36:38 [info     ] compilation took 1079.3314765350078 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
2025-03-14 21:36:38 [info     ] starting inference run...     
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:12<00:00, 34.73s/it]
2025-03-14 21:36:38 [info     ] compilation took 1081.89390801001 sec.
2025-03-14 21:36:39 [info     ] starting inference run...     
2025-03-14 21:36:39 [info     ] compilation took 1077.1543154849933 sec.
2025-03-14 21:36:39 [info     ] compilation took 1075.7239800530078 sec.
2025-03-14 21:36:39 [info     ] starting inference run...     
2025-03-14 21:36:40 [info     ] starting inference run...     
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:22<00:00, 35.10s/it]
2025-03-14 21:36:50 [info     ] compilation took 1088.1632604240003 sec.
2025-03-14 21:36:50 [info     ] starting inference run...     
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:28<00:00, 35.32s/it]
2025-03-14 21:36:55 [info     ] compilation took 1096.8027802760043 sec.
2025-03-14 21:36:56 [info     ] starting inference run...     
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:59<00:00, 36.40s/it]
2025-03-14 21:37:08 [info     ] compilation took 1113.8591305939917 sec.
2025-03-14 21:37:08 [info     ] starting inference run...     
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [16:55<00:00, 36.26s/it]
2025-03-14 21:37:22 [info     ] compilation took 1120.5590810020076 sec.
2025-03-14 21:37:22 [info     ] starting inference run...     
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00,  5.00it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00,  2.98it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00,  4.08it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:09<00:00,  2.82it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:08<00:00,  3.34it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00,  4.22it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00,  4.09it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:11<00:00,  2.41it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:06<00:00,  4.50it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00,  5.10it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00,  5.27it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00,  4.80it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00,  5.39it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:05<00:00,  5.39it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.67it/s]
 29%|█████████████████████████████████████████████████████████████████████████████▍                                                                                                                                                                                                 | 8/28 [00:01<00:03,  6.08it/s]/home/jfacevedo_google_com/diffusers/src/diffusers/image_processor.py:147: RuntimeWarning: invalid value encountered in cast
  images = (images * 255).round().astype("uint8")
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.82it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.93it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.02it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.02it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.98it/s]
 71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                             | 20/28 [00:03<00:01,  6.03it/s]2025-03-14 21:38:32 [info     ] inference time: 5.962021178987925
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.89it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.09it/s]
2025-03-14 21:38:32 [info     ] avg. inference over 5 iterations took 7.2685392687970305 sec.
2025-03-14 21:38:32 [info     ] avg. inference over 5 iterations took 7.402720856998348 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.01it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.89it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.96it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.06it/s]
 71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                             | 20/28 [00:03<00:01,  6.01it/s]2025-03-14 21:38:38 [info     ] inference time: 5.950578948002658
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.87it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.09it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.00it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.86it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.99it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.05it/s]
2025-03-14 21:38:43 [info     ] avg. inference over 5 iterations took 6.763298449796276 sec.
 71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                                                                             | 20/28 [00:03<00:01,  6.04it/s]2025-03-14 21:38:44 [info     ] inference time: 5.949129879008979
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.92it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.10it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.02it/s]
 39%|██████████████████████████████████████████████████████████████████████████████████████████████████████████                                                                                                                                                                    | 11/28 [00:01<00:02,  5.98it/s]2025-03-14 21:38:46 [info     ] avg. inference over 5 iterations took 7.221068455604836 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.96it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.08it/s]
 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                   | 26/28 [00:04<00:00,  5.92it/s]2025-03-14 21:38:50 [info     ] inference time: 5.954778069004533
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.90it/s]
 11%|█████████████████████████████                                                                                                                                                                                                                                                  | 3/28 [00:00<00:04,  6.03it/s]2025-03-14 21:38:50 [info     ] avg. inference over 5 iterations took 6.05970350120042 sec.
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  6.02it/s]
 32%|███████████████████████████████████████████████████████████████████████████████████████                                                                                                                                                                                        | 9/28 [00:01<00:03,  5.99it/s]2025-03-14 21:38:51 [info     ] avg. inference over 5 iterations took 6.018543455796316 sec.
 54%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                                                             | 15/28 [00:02<00:02,  6.00it/s]2025-03-14 21:38:52 [info     ] avg. inference over 5 iterations took 5.9609976705978625 sec.
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████���█████████████████████████████████████████████████████████| 28/28 [00:04<00:00,  5.97it/s]
2025-03-14 21:38:56 [info     ] inference time: 5.944058528999449
2025-03-14 21:38:56 [info     ] avg. inference over 5 iterations took 5.952113320800708 sec.
2025-03-14 21:38:56 [info     ] saved metric information as /tmp/metrics_report.txt