Fabric is the fast and lightweight way to scale PyTorch models without boilerplate
Run on any device at any scale with expert-level control over your PyTorch training loop and scaling strategy. You can even write your own Trainer.
Fabric is designed for the most complex models of any size, including foundation models, LLMs, diffusion models, transformers, reinforcement learning, and active learning.
<table> <tr> <th>What to change</th> <th>Resulting Fabric Code (copy me!)</th> </tr> <tr> <td> <sub>+ import lightning as L
import torch; import torchvision as tv
dataset = tv.datasets.CIFAR10("data", download=True,
train=True,
transform=tv.transforms.ToTensor())
+ fabric = L.Fabric()
+ fabric.launch()
model = tv.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model.to(device)
+ model, optimizer = fabric.setup(model, optimizer)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
+ dataloader = fabric.setup_dataloaders(dataloader)
model.train()
num_epochs = 10
for epoch in range(num_epochs):
    for batch in dataloader:
        inputs, labels = batch
-       inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
-       loss.backward()
+       fabric.backward(loss)
        optimizer.step()
import lightning as L
import torch; import torchvision as tv
dataset = tv.datasets.CIFAR10("data", download=True,
train=True,
transform=tv.transforms.ToTensor())
fabric = L.Fabric()
fabric.launch()
model = tv.models.resnet18()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
model, optimizer = fabric.setup(model, optimizer)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
dataloader = fabric.setup_dataloaders(dataloader)
model.train()
num_epochs = 10
for epoch in range(num_epochs):
    for batch in dataloader:
        inputs, labels = batch
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        fabric.backward(loss)
        optimizer.step()
from lightning.fabric import Fabric

# Use your available hardware
# no code changes needed
fabric = Fabric()
# Run on GPUs (CUDA or MPS)
fabric = Fabric(accelerator="gpu")
# 8 GPUs
fabric = Fabric(accelerator="gpu", devices=8)
# 256 GPUs, multi-node
fabric = Fabric(accelerator="gpu", devices=8, num_nodes=32)
# Run on TPUs
fabric = Fabric(accelerator="tpu")
# Use state-of-the-art distributed training techniques
fabric = Fabric(strategy="ddp")
fabric = Fabric(strategy="deepspeed")
fabric = Fabric(strategy="fsdp")
# Switch the precision
fabric = Fabric(precision="16-mixed")
fabric = Fabric(precision="64")
# no more of this!
- model.to(device)
- batch.to(device)
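
The `devices` and `num_nodes` arguments shown above multiply: 8 devices on each of 32 nodes gives 256 processes. The sketch below works through that arithmetic in plain Python. `ClusterLayout` is a hypothetical helper written for this illustration, not part of Lightning, and the global batch size assumes data-parallel training in which every process draws its own batch.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ClusterLayout:
    """Hypothetical helper (not a Lightning class) describing a multi-node job."""

    devices: int    # devices per node, as passed to Fabric(devices=...)
    num_nodes: int  # number of machines, as passed to Fabric(num_nodes=...)

    @property
    def world_size(self) -> int:
        # One process per device, so the total is devices per node times nodes.
        return self.devices * self.num_nodes

    def global_batch_size(self, per_process_batch_size: int) -> int:
        # Assumption: data-parallel training, where each process loads its own
        # batch, so the effective batch grows with the world size.
        return per_process_batch_size * self.world_size


layout = ClusterLayout(devices=8, num_nodes=32)
print(layout.world_size)            # 256
print(layout.global_batch_size(8))  # 2048
```

With the `batch_size=8` used in the example at the top, a 256-GPU run would therefore see 2048 samples per optimizer step, which is worth keeping in mind when choosing a learning rate.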
import lightning as L
import torch


class MyCustomTrainer:
    def __init__(self, accelerator="auto", strategy="auto", devices="auto", precision="32-true"):
        self.fabric = L.Fabric(accelerator=accelerator, strategy=strategy, devices=devices, precision=precision)

    def fit(self, model, optimizer, dataloader, max_epochs, loss_fn=torch.nn.functional.cross_entropy):
        # loss_fn is any callable loss; cross-entropy is only a default for illustration
        self.fabric.launch()

        # Let Fabric place the model, optimizer, and data on the right device(s)
        model, optimizer = self.fabric.setup(model, optimizer)
        dataloader = self.fabric.setup_dataloaders(dataloader)
        model.train()

        for epoch in range(max_epochs):
            for batch in dataloader:
                inputs, targets = batch
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_fn(outputs, targets)
                self.fabric.backward(loss)
                optimizer.step()
You can find a more extensive example in our examples.
</details>

Lightning is rigorously tested across multiple CPUs and GPUs and against major Python and PyTorch versions.
| System / PyTorch ver. | 1.12 | 1.13 | 2.0 | 2.1 |
|---|---|---|---|---|
| Linux py3.9 [GPUs] | ||||
| Linux (multiple Python versions) | ||||
| OSX (multiple Python versions) | ||||
| Windows (multiple Python versions) | ||||
TIP: We strongly recommend creating a virtual environment first. Don’t know what this is? Follow our beginner guide here.
pip install -U lightning
Create the Fabric object at the beginning of your training code.
import lightning as L
fabric = L.Fabric()
Call setup() on each model and optimizer pair and setup_dataloaders() on all your data loaders.
model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(dataloader)
Remove all .to() and .cuda() calls; Fabric takes care of device placement.
- model.to(device)
- batch.to(device)
Replace loss.backward() with fabric.backward(loss).
- loss.backward()
+ fabric.backward(loss)
Run the script from the terminal with
lightning run model path/to/train.py
or use the launch() method in a notebook. Learn more about launching distributed training.
Can you use Fabric with your existing Lightning code? Yes :) Fabric works with PyTorch LightningModules and Callbacks, so you can choose how to structure your code and reuse existing models and callbacks as you wish. Read more here.
If you have any questions please: