
StyleGAN 2 Model Training


This is the training code for the StyleGAN 2 model.

These are 64×64 images generated after training for about 80K steps.

Our implementation is a minimalistic StyleGAN 2 training code. To keep things simple, only single-GPU training is supported, and the whole implementation, including the training loop, is less than 500 lines of code.

Without DDP (distributed data parallel) and multi-GPU training it will not be possible to train the model at large resolutions (128×128 and above). If you want training code with fp16 and DDP, take a look at lucidrains/stylegan2-pytorch.

We trained this on the CelebA-HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/stylegan folder.

import math
from pathlib import Path
from typing import Iterator, Tuple

import torchvision
from PIL import Image

import torch
import torch.utils.data
from labml import tracker, lab, monit, experiment
from labml.configs import BaseConfigs
from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.trainer import ModeState
from labml_nn.utils import cycle_dataloader

Dataset

This loads the training dataset and resizes it to the given image size.

class Dataset(torch.utils.data.Dataset):

  • path: path to the folder containing the images
  • image_size: size of the images

    def __init__(self, path: str, image_size: int):
        super().__init__()

Get the paths of all jpg files

        self.paths = [p for p in Path(path).glob(f'**/*.jpg')]

Transformation

        self.transform = torchvision.transforms.Compose([

Resize the image

            torchvision.transforms.Resize(image_size),

Convert to PyTorch tensor

            torchvision.transforms.ToTensor(),
        ])

Number of images

    def __len__(self):
        return len(self.paths)

Get the index-th image

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        return self.transform(img)

Configurations

class Configs(BaseConfigs):

Device to train the model on. DeviceConfigs picks up an available CUDA device or defaults to CPU.

    device: torch.device = DeviceConfigs()

StyleGAN2 Discriminator

    discriminator: Discriminator

StyleGAN2 Generator

    generator: Generator

Mapping network

    mapping_network: MappingNetwork

Discriminator and generator loss functions. We use Wasserstein loss.

    discriminator_loss: DiscriminatorLoss
    generator_loss: GeneratorLoss

Optimizers

    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
    mapping_network_optimizer: torch.optim.Adam

Gradient penalty regularization loss

    gradient_penalty = GradientPenalty()

Gradient penalty coefficient γ

    gradient_penalty_coefficient: float = 10.

Path length penalty

    path_length_penalty: PathLengthPenalty

Data loader

    loader: Iterator

Batch size

    batch_size: int = 32

Dimensionality of z and w

    d_latent: int = 512

Height/width of the image

    image_size: int = 32

Number of layers in the mapping network

    mapping_network_layers: int = 8

Generator and discriminator learning rate

    learning_rate: float = 1e-3

Mapping network learning rate (100× lower than the others)

    mapping_network_learning_rate: float = 1e-5

Number of steps to accumulate gradients on. Use this to increase the effective batch size.

    gradient_accumulate_steps: int = 1
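Accumulating gradients over k micro-batches (with each loss scaled by 1/k) produces the same gradient as one backward pass over the combined batch, so the effective batch size becomes batch_size × gradient_accumulate_steps. A minimal sketch of that equivalence (names are illustrative, not from the code above):

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)
data = torch.randn(8, 4)

# One backward pass over the full batch
loss = (data @ w).pow(2).mean()
loss.backward()
full_grad = w.grad.clone()

# Two accumulated passes over half-batches, each loss scaled by 1/2
w.grad = None
for chunk in data.chunk(2):
    ((chunk @ w).pow(2).mean() / 2).backward()

# The accumulated gradient matches the full-batch gradient
assert torch.allclose(full_grad, w.grad, atol=1e-5)
```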

β₁ and β₂ for the Adam optimizer

    adam_betas: Tuple[float, float] = (0.0, 0.99)

Probability of mixing styles

    style_mixing_prob: float = 0.9

Total number of training steps

    training_steps: int = 150_000

Number of blocks in the generator (calculated based on image resolution)

    n_gen_blocks: int

Lazy regularization

Instead of calculating the regularization losses at every step, the paper proposes lazy regularization, where the regularization terms are calculated only once in a while. This improves training efficiency a lot.
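The schedule can be sketched as follows: the penalty is computed only every interval-th step, scaled by the interval so that its average contribution over training matches computing it at every step (a hypothetical helper, not code from this implementation):

```python
interval = 4

def penalty_weight(idx: int, interval: int) -> float:
    # Multiplier applied to the regularization term at step `idx` (0-based):
    # `interval` on the steps where it is computed, zero otherwise
    return float(interval) if (idx + 1) % interval == 0 else 0.0

weights = [penalty_weight(i, interval) for i in range(8)]
# Computed on steps 3 and 7, with weight 4 each
assert weights[3] == 4.0
# The average weight over the steps equals 1, the same as an every-step penalty
assert sum(weights) / len(weights) == 1.0
```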

The interval at which to compute the gradient penalty

    lazy_gradient_penalty_interval: int = 4

Path length penalty calculation interval

    lazy_path_penalty_interval: int = 32

Skip calculating the path length penalty during the initial phase of training

    lazy_path_penalty_after: int = 5_000

How often to log generated images

    log_generated_interval: int = 500

How often to save model checkpoints

    save_checkpoint_interval: int = 2_000

Training mode state for logging activations

    mode: ModeState

We trained this on the CelebA-HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/stylegan folder.

    dataset_path: str = str(lab.get_data_path() / 'stylegan2')

Initialize

    def init(self):

Create dataset

        dataset = Dataset(self.dataset_path, self.image_size)

Create data loader

        dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
                                                 shuffle=True, drop_last=True, pin_memory=True)

Continuous cyclic loader

        self.loader = cycle_dataloader(dataloader)
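A cyclic loader lets the training loop call next() indefinitely without worrying about epoch boundaries. A minimal sketch of the idea (the actual helper is labml_nn.utils.cycle_dataloader; this toy version iterates over a plain list):

```python
def cycle(loader):
    # Restart iteration whenever the underlying loader is exhausted,
    # so `next()` never raises StopIteration
    while True:
        for batch in loader:
            yield batch

it = cycle([1, 2, 3])
batches = [next(it) for _ in range(7)]
assert batches == [1, 2, 3, 1, 2, 3, 1]
```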

log₂ of the image resolution

        log_resolution = int(math.log2(self.image_size))

Create discriminator and generator

        self.discriminator = Discriminator(log_resolution).to(self.device)
        self.generator = Generator(log_resolution, self.d_latent).to(self.device)

Get the number of generator blocks for creating style and noise inputs

        self.n_gen_blocks = self.generator.n_blocks

Create mapping network

        self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)

Create path length penalty loss

        self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)

Discriminator and generator losses

        self.discriminator_loss = DiscriminatorLoss().to(self.device)
        self.generator_loss = GeneratorLoss().to(self.device)

Create optimizers

        self.discriminator_optimizer = torch.optim.Adam(
            self.discriminator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.generator_optimizer = torch.optim.Adam(
            self.generator.parameters(),
            lr=self.learning_rate, betas=self.adam_betas
        )
        self.mapping_network_optimizer = torch.optim.Adam(
            self.mapping_network.parameters(),
            lr=self.mapping_network_learning_rate, betas=self.adam_betas
        )

Set tracker configurations

        tracker.set_image("generated", True)

Sample w

This samples z randomly and gets w from the mapping network.

We also sometimes apply style mixing, where we generate two latent variables z₁ and z₂ and get the corresponding w₁ and w₂. Then we randomly sample a cross-over point and apply w₁ to the generator blocks before the cross-over point and w₂ to the blocks after it.
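The shape logic of style mixing can be sketched in isolation: each of the n_gen_blocks generator blocks receives one w vector per sample, so the mixed result is a tensor of shape (n_blocks, batch_size, d_latent). A toy sketch with made-up sizes (not the actual get_w code):

```python
import torch

n_blocks, batch_size, d_latent = 6, 2, 8
cross_over = 2

w1 = torch.randn(batch_size, d_latent)
w2 = torch.randn(batch_size, d_latent)

# First `cross_over` blocks get w1, the remaining blocks get w2
mixed = torch.cat((
    w1[None].expand(cross_over, -1, -1),
    w2[None].expand(n_blocks - cross_over, -1, -1),
), dim=0)

assert mixed.shape == (n_blocks, batch_size, d_latent)
assert torch.equal(mixed[0], w1)
assert torch.equal(mixed[-1], w2)
```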

    def get_w(self, batch_size: int):

Mix styles

        if torch.rand(()).item() < self.style_mixing_prob:

Random cross-over point

            cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)

Sample z₁ and z₂

            z2 = torch.randn(batch_size, self.d_latent).to(self.device)
            z1 = torch.randn(batch_size, self.d_latent).to(self.device)

Get w₁ and w₂

            w1 = self.mapping_network(z1)
            w2 = self.mapping_network(z2)

Expand w₁ and w₂ for the generator blocks and concatenate

            w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
            w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
            return torch.cat((w1, w2), dim=0)

Without mixing

        else:

Sample z

            z = torch.randn(batch_size, self.d_latent).to(self.device)

Get w

            w = self.mapping_network(z)

Expand w for the generator blocks

            return w[None, :, :].expand(self.n_gen_blocks, -1, -1)

Generate noise

This generates noise for each generator block.

    def get_noise(self, batch_size: int):

List to store noise

        noise = []

Noise resolution starts from 4

        resolution = 4

Generate noise for each generator block

        for i in range(self.n_gen_blocks):

The first block has only one 3×3 convolution

            if i == 0:
                n1 = None

Generate noise to add after the first convolution layer

            else:
                n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

Generate noise to add after the second convolution layer

            n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)

Add noise tensors to the list

            noise.append((n1, n2))

The next block has 2× the resolution

            resolution *= 2

Return noise tensors

        return noise
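The noise resolutions therefore double per block, starting from 4×4 and ending at the image resolution. A sketch of just that schedule (a hypothetical helper, not the generator code; it assumes n_blocks = log₂(image_size) − 1, matching the implementation above):

```python
def noise_resolutions(n_blocks: int):
    # Starting at 4x4, each successive block doubles the noise resolution
    resolution, out = 4, []
    for _ in range(n_blocks):
        out.append(resolution)
        resolution *= 2
    return out

# For a 64x64 image there are 5 blocks: 4, 8, 16, 32, 64
assert noise_resolutions(5) == [4, 8, 16, 32, 64]
```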

Generate images

This generates images using the generator.

    def generate_images(self, batch_size: int):

Get w

        w = self.get_w(batch_size)

Get noise

        noise = self.get_noise(batch_size)

Generate images

        images = self.generator(w, noise)

Return images and w

        return images, w

Training Step

    def step(self, idx: int):

Train the discriminator

        with monit.section('Discriminator'):

Reset gradients

            self.discriminator_optimizer.zero_grad()

Accumulate gradients for gradient_accumulate_steps

            for i in range(self.gradient_accumulate_steps):

Sample images from the generator

                generated_images, _ = self.generate_images(self.batch_size)

Discriminator classification for generated images

                fake_output = self.discriminator(generated_images.detach())

Get real images from the data loader

                real_images = next(self.loader).to(self.device)

We need to calculate gradients w.r.t. the real images for the gradient penalty

                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
                    real_images.requires_grad_()

Discriminator classification for real images

                real_output = self.discriminator(real_images)

Get discriminator loss

                real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
                disc_loss = real_loss + fake_loss

Add gradient penalty

                if (idx + 1) % self.lazy_gradient_penalty_interval == 0:

Calculate and log gradient penalty

                    gp = self.gradient_penalty(real_images, real_output)
                    tracker.add('loss.gp', gp)

Multiply by the coefficient and add the gradient penalty, scaled by the lazy-regularization interval

                    disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval
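The gradient penalty measures the norm of the discriminator's gradient with respect to its real-image inputs. A toy sketch of that quantity for a function whose gradient we know in closed form (illustrative only, not the GradientPenalty class): for d(x) = 3x the gradient w.r.t. x is 3 everywhere, so the squared gradient norm is 9.

```python
import torch

# x plays the role of a "real image"; d plays the role of D(x)
x = torch.tensor([2.0], requires_grad=True)
d = 3 * x

# create_graph=True would let the penalty itself be backpropagated,
# as is needed when it is added to the discriminator loss
grad, = torch.autograd.grad(outputs=d.sum(), inputs=x, create_graph=True)
gp = (grad ** 2).sum()
assert gp.item() == 9.0
```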

Compute gradients

                disc_loss.backward()

Log discriminator loss

            tracker.add('loss.discriminator', disc_loss)

            if (idx + 1) % self.log_generated_interval == 0:

Log discriminator model parameters occasionally

                tracker.add('discriminator', self.discriminator)

Clip gradients for stabilization

            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
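clip_grad_norm_ rescales all gradients in place so that their combined norm is at most max_norm, which keeps a single large gradient from destabilizing training. A minimal sketch on a single parameter with a hand-set gradient:

```python
import torch

p = torch.nn.Parameter(torch.ones(4))
p.grad = torch.full((4,), 2.0)   # gradient norm = sqrt(4 * 2^2) = 4

torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)

# The gradient is scaled down so its norm is (approximately) max_norm
assert abs(p.grad.norm().item() - 1.0) < 1e-4
```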

Take an optimizer step

            self.discriminator_optimizer.step()

Train the generator

        with monit.section('Generator'):

Reset gradients

            self.generator_optimizer.zero_grad()
            self.mapping_network_optimizer.zero_grad()

Accumulate gradients for gradient_accumulate_steps

            for i in range(self.gradient_accumulate_steps):

Sample images from the generator

                generated_images, w = self.generate_images(self.batch_size)

Discriminator classification for generated images

                fake_output = self.discriminator(generated_images)

Get generator loss

                gen_loss = self.generator_loss(fake_output)

Add path length penalty

                if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:

Calculate path length penalty

                    plp = self.path_length_penalty(w, generated_images)

Ignore if NaN

                    if not torch.isnan(plp):
                        tracker.add('loss.plp', plp)
                        gen_loss = gen_loss + plp

Calculate gradients

                gen_loss.backward()

Log generator loss

            tracker.add('loss.generator', gen_loss)

            if (idx + 1) % self.log_generated_interval == 0:

Log generator and mapping network model parameters occasionally

                tracker.add('generator', self.generator)
                tracker.add('mapping_network', self.mapping_network)

Clip gradients for stabilization

            torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)

Take an optimizer step

            self.generator_optimizer.step()
            self.mapping_network_optimizer.step()

Log generated images

        if (idx + 1) % self.log_generated_interval == 0:
            tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))

Save model checkpoints

        if (idx + 1) % self.save_checkpoint_interval == 0:
            experiment.save_checkpoint()

Flush tracker

        tracker.save()

Train model

    def train(self):

Loop for training_steps

        for i in monit.loop(self.training_steps):

Take a training step

            self.step(i)

            if (i + 1) % self.log_generated_interval == 0:
                tracker.new_line()

Train StyleGAN2

def main():

Create an experiment

    experiment.create(name='stylegan2')

Create configurations object

    configs = Configs()

Set configurations and override some

    experiment.configs(configs, {
        'device.cuda_device': 0,
        'image_size': 64,
        'log_generated_interval': 200
    })

Initialize

    configs.init()

Set models for saving and loading

    experiment.add_pytorch_models(mapping_network=configs.mapping_network,
                                  generator=configs.generator,
                                  discriminator=configs.discriminator)

Start the experiment

    with experiment.start():

Run the training loop

        configs.train()

if __name__ == '__main__':
    main()
