Back to Annotated Deep Learning Paper Implementations

Denoising Diffusion Probabilistic Models (DDPM) training

docs/diffusion/ddpm/experiment.html

latest7.5 KB
Original Source

homediffusionddpm

View code on Github

#

Denoising Diffusion Probabilistic Models (DDPM) training

This trains a DDPM-based model on the CelebA HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/celebA folder.

The paper used an exponential moving average of the model weights with a decay of 0.9999. We have skipped this for simplicity.

20fromtypingimportList2122importtorchvision23fromPILimportImage2425importtorch26importtorch.utils.data27fromlabmlimportlab,tracker,experiment,monit28fromlabml.configsimportBaseConfigs,option29fromlabml\_nn.diffusion.ddpmimportDenoiseDiffusion30fromlabml\_nn.diffusion.ddpm.unetimportUNet31fromlabml\_nn.helpers.deviceimportDeviceConfigs

#

Configurations

class Configs(BaseConfigs):
    """
    ## Configurations

    Training configuration for a DDPM model: holds the U-Net noise predictor,
    the diffusion wrapper, dataset/dataloader, and the optimizer, and provides
    the `sample`/`train`/`run` loops.
    """
    # Device to train the model on. `DeviceConfigs` picks up an available
    # CUDA device or defaults to CPU.
    device: torch.device = DeviceConfigs()

    # U-Net model that predicts the noise $\epsilon_\theta(x_t, t)$
    eps_model: UNet
    # DDPM algorithm wrapper (forward noising + reverse sampling + loss)
    diffusion: DenoiseDiffusion

    # Number of channels in the image. 3 for RGB.
    image_channels: int = 3
    # Image size (images are resized to `image_size` x `image_size`)
    image_size: int = 32
    # Number of channels in the initial feature map of the U-Net
    n_channels: int = 64
    # Channel multipliers per resolution; channels at level i are
    # `channel_multipliers[i] * n_channels`
    channel_multipliers: List[int] = [1, 2, 2, 4]
    # Whether to use attention at each resolution.
    # NOTE: was annotated `List[int]` but the values are booleans — fixed to `List[bool]`.
    is_attention: List[bool] = [False, False, False, True]

    # Number of diffusion time steps $T$
    n_steps: int = 1_000
    # Batch size
    batch_size: int = 64
    # Number of samples to generate when logging images
    n_samples: int = 16
    # Learning rate
    learning_rate: float = 2e-5
    # Number of training epochs
    epochs: int = 1_000

    # Dataset
    dataset: torch.utils.data.Dataset
    # Dataloader
    data_loader: torch.utils.data.DataLoader
    # Adam optimizer
    optimizer: torch.optim.Adam

    def init(self):
        """Build the model, diffusion wrapper, dataloader, and optimizer."""
        # Create the $\epsilon_\theta(x_t, t)$ model
        self.eps_model = UNet(
            image_channels=self.image_channels,
            n_channels=self.n_channels,
            ch_mults=self.channel_multipliers,
            is_attn=self.is_attention,
        ).to(self.device)

        # Create the DDPM wrapper
        self.diffusion = DenoiseDiffusion(
            eps_model=self.eps_model,
            n_steps=self.n_steps,
            device=self.device,
        )

        # Create the dataloader
        self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size,
                                                       shuffle=True, pin_memory=True)
        # Create the optimizer
        self.optimizer = torch.optim.Adam(self.eps_model.parameters(), lr=self.learning_rate)

        # Log generated samples as images
        tracker.set_image("sample", True)

    def sample(self):
        """Generate `n_samples` images by running the reverse diffusion process."""
        with torch.no_grad():
            # Start from pure noise $x_T \sim p(x_T) = \mathcal{N}(x_T; \mathbf{0}, \mathbf{I})$
            x = torch.randn([self.n_samples, self.image_channels, self.image_size, self.image_size],
                            device=self.device)

            # Remove noise for $T$ steps
            for t_ in monit.iterate('Sample', self.n_steps):
                # Current time step, counting down from $T - 1$ to $0$
                t = self.n_steps - t_ - 1
                # Sample from $p_\theta(x_{t-1} | x_t)$
                x = self.diffusion.p_sample(x, x.new_full((self.n_samples,), t, dtype=torch.long))

            # Log the generated samples
            tracker.save('sample', x)

    def train(self):
        """Run one training epoch over the dataset."""
        for data in monit.iterate('Train', self.data_loader):
            # Increment global step
            tracker.add_global_step()
            # Move data to device
            data = data.to(self.device)

            # Make the gradients zero
            self.optimizer.zero_grad()
            # Calculate the simplified DDPM loss
            loss = self.diffusion.loss(data)
            # Compute gradients
            loss.backward()
            # Take an optimization step
            self.optimizer.step()
            # Track the loss
            tracker.save('loss', loss)

    def run(self):
        """Training loop: alternate training epochs and sampling."""
        for _ in monit.loop(self.epochs):
            # Train the model
            self.train()
            # Sample some images
            self.sample()
            # New line in the console
            tracker.new_line()

#

CelebA HQ dataset

class CelebADataset(torch.utils.data.Dataset):
    """
    ### CelebA HQ dataset

    Loads `.jpg` images from `data/celebA`, resized to `image_size` and
    converted to tensors in `[0, 1]`.
    """

    def __init__(self, image_size: int):
        super().__init__()

        # CelebA images folder
        folder = lab.get_data_path() / 'celebA'
        # List of image files. Sorted so the index -> file mapping is
        # deterministic across runs (glob order is filesystem-dependent);
        # the f-string prefix on the pattern was removed — it had no placeholders.
        self._files = sorted(folder.glob('**/*.jpg'))

        # Transformations to resize the image and convert it to a tensor
        self._transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

    def __len__(self):
        """Size of the dataset."""
        return len(self._files)

    def __getitem__(self, index: int):
        """Load, resize, and tensor-ize the image at `index`."""
        img = Image.open(self._files[index])
        return self._transform(img)

#

Create CelebA dataset

@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
    """Create the CelebA HQ dataset option for `Configs.dataset`."""
    dataset = CelebADataset(c.image_size)
    return dataset

#

MNIST dataset

class MNISTDataset(torchvision.datasets.MNIST):
    """
    ### MNIST dataset

    Thin wrapper over `torchvision.datasets.MNIST` that resizes images and
    returns only the image tensor (the label is dropped).
    """

    def __init__(self, image_size):
        # Resize to `image_size` and convert to a tensor
        transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize(image_size),
            torchvision.transforms.ToTensor(),
        ])

        super().__init__(str(lab.get_data_path()), train=True, download=True, transform=transform)

    def __getitem__(self, item):
        # Keep only the image; discard the class label
        image, _ = super().__getitem__(item)
        return image

#

Create MNIST dataset

@option(Configs.dataset, 'MNIST')
def mnist_dataset(c: Configs):
    """Create the MNIST dataset option for `Configs.dataset`."""
    return MNISTDataset(c.image_size)

#

def main():
    """Set up the experiment, configure it, and run the training loop."""
    # Create the experiment
    experiment.create(name='diffuse', writers={'screen', 'labml'})

    # Create configurations
    configs = Configs()

    # Set configurations. Override the defaults by passing values in the dictionary.
    overrides = {
        'dataset': 'CelebA',  # 'MNIST'
        'image_channels': 3,  # 1,
        'epochs': 100,  # 5,
    }
    experiment.configs(configs, overrides)

    # Initialize the models, dataset, and optimizer
    configs.init()

    # Register models for saving and loading
    experiment.add_pytorch_models({'eps_model': configs.eps_model})

    # Start the experiment and run the training loop
    with experiment.start():
        configs.run()

#

# Run the training script when executed directly
if __name__ == '__main__':
    main()

labml.ai