This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/celebA folder.
The paper used an exponential moving average of the model weights with a decay of 0.9999. We have skipped this for simplicity.
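For reference, such a weight EMA can be sketched as below. This helper is an illustration and not part of the experiment code; it simply keeps a frozen shadow copy of the model and blends in the live weights after each update.

```python
import copy

import torch


class EMA:
    """Exponential moving average of a model's parameters (sketch, not part of labml_nn)."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        self.decay = decay
        # A frozen copy of the model holds the averaged weights
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        # shadow = decay * shadow + (1 - decay) * current
        for s, p in zip(self.shadow.parameters(), model.parameters()):
            s.mul_(self.decay).add_(p, alpha=1.0 - self.decay)
```

You would call `ema.update(model)` after each optimizer step and sample with `ema.shadow` instead of `model`.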
from typing import List

import torchvision
from PIL import Image

import torch
import torch.utils.data
from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs, option
from labml_nn.diffusion.ddpm import DenoiseDiffusion
from labml_nn.diffusion.ddpm.unet import UNet
from labml_nn.helpers.device import DeviceConfigs
class Configs(BaseConfigs):
Device to train the model on. DeviceConfigs picks up an available CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()
U-Net model for ϵθ(xt,t)
    eps_model: UNet
    diffusion: DenoiseDiffusion
Number of channels in the image. 3 for RGB.
    image_channels: int = 3
Image size
    image_size: int = 32
Number of channels in the initial feature map
    n_channels: int = 64
The list of channel numbers at each resolution. The number of channels is channel_multipliers[i] * n_channels
    channel_multipliers: List[int] = [1, 2, 2, 4]
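To make the multiplier configuration concrete, here is the channel count at each resolution, computed from the defaults above:

```python
n_channels = 64
channel_multipliers = [1, 2, 2, 4]

# Channels at each resolution: channel_multipliers[i] * n_channels
channels = [m * n_channels for m in channel_multipliers]
print(channels)  # [64, 128, 128, 256]
```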
The list of booleans that indicate whether to use attention at each resolution
    is_attention: List[bool] = [False, False, False, True]
Number of time steps T
    n_steps: int = 1_000
Batch size
    batch_size: int = 64
Number of samples to generate
    n_samples: int = 16
Learning rate
    learning_rate: float = 2e-5
Number of training epochs
    epochs: int = 1_000
Dataset
    dataset: torch.utils.data.Dataset
Dataloader
    data_loader: torch.utils.data.DataLoader
Adam optimizer
    optimizer: torch.optim.Adam
    def init(self):
Create ϵθ(xt,t) model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)
Create DDPM class
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )
Create dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, shuffle=True, pin_memory=True)
Create optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)
Image logging
        tracker.set_image("sample", True)
    def sample(self):
        with torch.no_grad():
xT∼p(xT)=N(xT;0,I)
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)
Remove noise for T steps
            for t_ in monit.iterate('Sample', self.n_steps):
Current time step t = T − t′ − 1, stepping backwards from T − 1 to 0
                t = self.n_steps - t_ - 1
Sample from pθ(xt−1∣xt)
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))
Log samples
            tracker.save('sample', x)
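The reverse step computed by each call above follows the standard DDPM sampling rule. The sketch below is an illustration of that rule, not the actual DenoiseDiffusion.p_sample implementation; the names alpha, alpha_bar, and beta are assumptions for the usual schedule tensors.

```python
import torch


def p_sample_sketch(x_t, t: int, eps_theta, alpha, alpha_bar, beta):
    """One reverse-diffusion step (sketch of the standard DDPM rule).

    eps_theta is the model's predicted noise for step t;
    alpha = 1 - beta, alpha_bar is the cumulative product of alpha.
    """
    # Mean of p(x_{t-1} | x_t): remove the predicted noise component
    coef = (1 - alpha[t]) / (1 - alpha_bar[t]) ** 0.5
    mean = (x_t - coef * eps_theta) / alpha[t] ** 0.5
    # No noise is added at the final step (t == 0)
    noise = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
    return mean + (beta[t] ** 0.5) * noise
```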
    def train(self):
Iterate through the dataset
        for data in monit.iterate('Train', self.data_loader):
Increment global step
            tracker.add_global_step()
Move data to device
            data = data.to(self.device)
Make the gradients zero
            self.optimizer.zero_grad()
Calculate loss
            loss = self.diffusion.loss(data)
Compute gradients
            loss.backward()
Take an optimization step
            self.optimizer.step()
Track the loss
            tracker.save('loss', loss)
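The loss minimized here is the simplified DDPM objective: add noise to a clean image at a random time step and regress the model's noise prediction against the true noise. The sketch below illustrates that objective; it is not the actual DenoiseDiffusion.loss implementation, and alpha_bar is an assumed name for the cumulative schedule tensor.

```python
import torch


def ddpm_loss_sketch(eps_model, x0, alpha_bar, n_steps: int):
    """Simplified DDPM training objective (sketch)."""
    batch = x0.shape[0]
    # Sample a random time step for each image in the batch
    t = torch.randint(0, n_steps, (batch,), device=x0.device)
    # Sample Gaussian noise eps ~ N(0, I)
    eps = torch.randn_like(x0)
    # q(x_t | x_0): x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    a = alpha_bar[t].view(batch, 1, 1, 1)
    xt = a.sqrt() * x0 + (1 - a).sqrt() * eps
    # MSE between the true noise and the model's prediction
    return torch.nn.functional.mse_loss(eps_model(xt, t), eps)
```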
    def run(self):
        for _ in monit.loop(self.epochs):
Train the model
            self.train()
Sample some images
            self.sample()
New line in the console
            tracker.new_line()
class CelebADataset(torch.utils.data.Dataset):
    def __init__(self, image_size: int):
        super().__init__()
CelebA images folder
        folder = lab.get_data_path() / 'celebA'
List of files
        self._files = [p for p in folder.glob('**/*.jpg')]
Transformations to resize the image and convert to tensor
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])
Size of the dataset
    def __len__(self):
        return len(self._files)
Get an image
    def __getitem__(self, index: int):
        img = Image.open(self._files[index])
        return self._transform(img)
Create CelebA dataset
@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    return CelebADataset(c.image_size)
class MNISTDataset(torchvision.datasets.MNIST):
    def __init__(self, image_size):
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)

    def __getitem__(self, item):
        return super().__getitem__(item)[0]
Create MNIST dataset
@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    return MNISTDataset(c.image_size)
def main():
Create experiment
    experiment.create(name='diffuse', writers={'screen', 'labml'})
Create configurations
    configs = Configs()
Set configurations. You can override the defaults by passing the values in the dictionary.
    experiment.configs(configs, {
        'dataset': 'CelebA',  # 'MNIST'
        'image_channels': 3,  # 1,
        'epochs': 100,  # 5,
    })
Initialize
    configs.init()
Set models for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})
Start and run the training loop
    with experiment.start():
        configs.run()


if __name__ == '__main__':
    main()