This is the training code for the StyleGAN 2 model.
These are 64×64 images generated after training for about 80K steps.
This is a minimalistic implementation of StyleGAN 2 training. Only single-GPU training is supported to keep the code simple; even so, it comes in at under 500 lines, including the training loop.
Without DDP (distributed data parallel) and multi-GPU training it is not practical to train the model at larger resolutions (128×128 and above). If you want training code with fp16 and DDP, take a look at lucidrains/stylegan2-pytorch.
We trained this on the CelebA-HQ dataset. You can find the download instructions in this discussion on fast.ai. Save the images inside the data/stylegan folder.
import math
from pathlib import Path
from typing import Iterator, Tuple

import torchvision
from PIL import Image

import torch
import torch.utils.data
from labml import tracker, lab, monit, experiment
from labml.configs import BaseConfigs
from labml_nn.gan.stylegan import Discriminator, Generator, MappingNetwork, GradientPenalty, PathLengthPenalty
from labml_nn.gan.wasserstein import DiscriminatorLoss, GeneratorLoss
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.trainer import ModeState
from labml_nn.utils import cycle_dataloader
This loads the training dataset and resizes the images to the given image size.
class Dataset(torch.utils.data.Dataset):
path is the path to the folder containing the images
image_size is the size of the image

def __init__(self, path: str, image_size: int):
    super().__init__()
Get the paths of all jpg files
self.paths = [p for p in Path(path).glob('**/*.jpg')]
Transformation
self.transform = torchvision.transforms.Compose([
Resize the image
    torchvision.transforms.Resize(image_size),
Convert to PyTorch tensor
    torchvision.transforms.ToTensor(),
])
Number of images
def __len__(self):
    return len(self.paths)
Get the image at index
def __getitem__(self, index):
    path = self.paths[index]
    img = Image.open(path)
    return self.transform(img)
class Configs(BaseConfigs):
Device to train the model on. DeviceConfigs picks up an available CUDA device or defaults to CPU.
device: torch.device = DeviceConfigs()
discriminator: Discriminator
generator: Generator
mapping_network: MappingNetwork
Discriminator and generator loss functions. We use Wasserstein loss.
discriminator_loss: DiscriminatorLoss
generator_loss: GeneratorLoss
Optimizers
generator_optimizer: torch.optim.Adam
discriminator_optimizer: torch.optim.Adam
mapping_network_optimizer: torch.optim.Adam
Gradient penalty regularization loss
gradient_penalty = GradientPenalty()
Gradient penalty coefficient γ
gradient_penalty_coefficient: float = 10.
Path length penalty
path_length_penalty: PathLengthPenalty
Data loader
loader: Iterator
Batch size
batch_size: int = 32
Dimensionality of z and w
d_latent: int = 512
Height/width of the image
image_size: int = 32
Number of layers in the mapping network
mapping_network_layers: int = 8
Generator & discriminator learning rate
learning_rate: float = 1e-3
Mapping network learning rate (100× lower than the others)
mapping_network_learning_rate: float = 1e-5
Number of steps to accumulate gradients on. Use this to increase the effective batch size.
gradient_accumulate_steps: int = 1
β1 and β2 for the Adam optimizer
adam_betas: Tuple[float, float] = (0.0, 0.99)
Probability of mixing styles
style_mixing_prob: float = 0.9
Total number of training steps
training_steps: int = 150_000
Number of blocks in the generator (calculated based on image resolution)
n_gen_blocks: int
Instead of calculating the regularization losses at every step, the paper proposes lazy regularization, where the regularization terms are computed only once in a while. This improves training efficiency significantly.
The interval at which to compute gradient penalty
lazy_gradient_penalty_interval: int = 4
Path length penalty calculation interval
lazy_path_penalty_interval: int = 32
Skip calculating path length penalty during the initial phase of training
lazy_path_penalty_after: int = 5_000
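The lazy-regularization schedule controlled by these intervals can be sketched in isolation. The helper below is illustrative only; the step indices assume the default interval of 4:

```python
def regularize_this_step(idx: int, interval: int = 4) -> bool:
    # The penalty is computed only when (idx + 1) is a multiple of
    # `interval`, matching the modulo checks used in the training step
    return (idx + 1) % interval == 0


# With interval = 4, the penalty is applied on steps 3, 7, 11, 15, ...
steps = [idx for idx in range(16) if regularize_this_step(idx)]
print(steps)  # [3, 7, 11, 15]
```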
How often to log generated images
log_generated_interval: int = 500
How often to save model checkpoints
save_checkpoint_interval: int = 2_000
Training mode state for logging activations
mode: ModeState
Path to the dataset. See the download instructions above.
dataset_path: str = str(lab.get_data_path() / 'stylegan2')

def init(self):
Create dataset
dataset = Dataset(self.dataset_path, self.image_size)
Create data loader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=8,
                                         shuffle=True, drop_last=True, pin_memory=True)
Continuous cyclic loader
self.loader = cycle_dataloader(dataloader)
log2 of image resolution
log_resolution = int(math.log2(self.image_size))
Create discriminator and generator
self.discriminator = Discriminator(log_resolution).to(self.device)
self.generator = Generator(log_resolution, self.d_latent).to(self.device)
Get number of generator blocks for creating style and noise inputs
self.n_gen_blocks = self.generator.n_blocks
Create mapping network
self.mapping_network = MappingNetwork(self.d_latent, self.mapping_network_layers).to(self.device)
Create path length penalty loss
self.path_length_penalty = PathLengthPenalty(0.99).to(self.device)
Discriminator and generator losses
self.discriminator_loss = DiscriminatorLoss().to(self.device)
self.generator_loss = GeneratorLoss().to(self.device)
Create optimizers
self.discriminator_optimizer = torch.optim.Adam(
    self.discriminator.parameters(),
    lr=self.learning_rate, betas=self.adam_betas)
self.generator_optimizer = torch.optim.Adam(
    self.generator.parameters(),
    lr=self.learning_rate, betas=self.adam_betas)
self.mapping_network_optimizer = torch.optim.Adam(
    self.mapping_network.parameters(),
    lr=self.mapping_network_learning_rate, betas=self.adam_betas)
Set tracker configurations
tracker.set_image("generated", True)
This samples z randomly and gets w from the mapping network.
We also apply style mixing sometimes where we generate two latent variables z1 and z2 and get corresponding w1 and w2. Then we randomly sample a cross-over point and apply w1 to the generator blocks before the cross-over point and w2 to the blocks after.
def get_w(self, batch_size: int):
Mix styles
if torch.rand(()).item() < self.style_mixing_prob:
Random cross-over point
cross_over_point = int(torch.rand(()).item() * self.n_gen_blocks)
Sample z1 and z2
z1 = torch.randn(batch_size, self.d_latent).to(self.device)
z2 = torch.randn(batch_size, self.d_latent).to(self.device)
Get w1 and w2
w1 = self.mapping_network(z1)
w2 = self.mapping_network(z2)
Expand w1 and w2 for the generator blocks and concatenate
w1 = w1[None, :, :].expand(cross_over_point, -1, -1)
w2 = w2[None, :, :].expand(self.n_gen_blocks - cross_over_point, -1, -1)
return torch.cat((w1, w2), dim=0)
Without mixing
else:
Sample z
z = torch.randn(batch_size, self.d_latent).to(self.device)
Get w
w = self.mapping_network(z)
Expand w for the generator blocks
return w[None, :, :].expand(self.n_gen_blocks, -1, -1)
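The mixing logic above can be exercised standalone with stand-in tensors. The sizes below are hypothetical, and constant tensors replace the mapping network so the cross-over point is visible in the result:

```python
import torch

n_blocks, batch_size, d_latent = 6, 4, 8

# Stand-ins for mapping_network(z1) and mapping_network(z2)
w1 = torch.zeros(batch_size, d_latent)
w2 = torch.ones(batch_size, d_latent)

# Generator blocks before the cross-over get w1, the rest get w2
cross_over_point = 2
w = torch.cat((
    w1[None, :, :].expand(cross_over_point, -1, -1),
    w2[None, :, :].expand(n_blocks - cross_over_point, -1, -1),
), dim=0)

print(tuple(w.shape))                        # (6, 4, 8)
print(w[1, 0, 0].item(), w[2, 0, 0].item())  # 0.0 1.0
```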
This generates noise for each generator block
def get_noise(self, batch_size: int):
List to store noise
noise = []
Noise resolution starts from 4
resolution = 4
Generate noise for each generator block
for i in range(self.n_gen_blocks):
The first block has only one 3×3 convolution
if i == 0:
    n1 = None
Generate noise to add after the first convolution layer
else:
    n1 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
Generate noise to add after the second convolution layer
n2 = torch.randn(batch_size, 1, resolution, resolution, device=self.device)
Add noise tensors to the list
noise.append((n1, n2))
Next block has 2× resolution
resolution *= 2
Return noise tensors
return noise
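For a concrete picture of the noise shapes, here is the resolution schedule this method produces, assuming (as in this implementation) the generator has log2(image_size) − 1 blocks; the 64×64 model is used as the example:

```python
import math

image_size = 64
n_gen_blocks = int(math.log2(image_size)) - 1  # 5 blocks for 64×64

# Resolution starts at 4 and doubles per block, ending at the image size
resolutions = [4 * 2 ** i for i in range(n_gen_blocks)]
print(resolutions)  # [4, 8, 16, 32, 64]
```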
This generates images using the generator.
def generate_images(self, batch_size: int):
Get w
w = self.get_w(batch_size)
Get noise
noise = self.get_noise(batch_size)
Generate images
images = self.generator(w, noise)
Return images and w
return images, w
def step(self, idx: int):
Train the discriminator
with monit.section('Discriminator'):
Reset gradients
self.discriminator_optimizer.zero_grad()
Accumulate gradients for gradient_accumulate_steps
for i in range(self.gradient_accumulate_steps):
Sample images from generator
generated_images, _ = self.generate_images(self.batch_size)
Discriminator classification for generated images
fake_output = self.discriminator(generated_images.detach())
Get real images from the data loader
real_images = next(self.loader).to(self.device)
We need to calculate gradients w.r.t. real images for gradient penalty
if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
    real_images.requires_grad_()
Discriminator classification for real images
real_output = self.discriminator(real_images)
Get discriminator loss
real_loss, fake_loss = self.discriminator_loss(real_output, fake_output)
disc_loss = real_loss + fake_loss
Add gradient penalty
if (idx + 1) % self.lazy_gradient_penalty_interval == 0:
Calculate and log gradient penalty
gp = self.gradient_penalty(real_images, real_output)
tracker.add('loss.gp', gp)
Multiply by coefficient and add gradient penalty
disc_loss = disc_loss + 0.5 * self.gradient_penalty_coefficient * gp * self.lazy_gradient_penalty_interval
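The extra factor of lazy_gradient_penalty_interval compensates for computing the penalty only once every interval steps, so its average contribution over training matches applying 0.5 γ · gp at every step. A quick check with the default values:

```python
gamma = 10.0     # gradient_penalty_coefficient
interval = 4     # lazy_gradient_penalty_interval

per_step_weight = 0.5 * gamma             # weight if applied every step
lazy_weight = per_step_weight * interval  # weight when applied every 4th step

# Averaged over `interval` steps, the two schedules match
print(lazy_weight / interval == per_step_weight)  # True
```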
Compute gradients
disc_loss.backward()
Log discriminator loss
tracker.add('loss.discriminator', disc_loss)

if (idx + 1) % self.log_generated_interval == 0:
Log discriminator model parameters occasionally
tracker.add('discriminator', self.discriminator)
Clip gradients for stabilization
torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), max_norm=1.0)
Take optimizer step
self.discriminator_optimizer.step()
Train the generator
with monit.section('Generator'):
Reset gradients
self.generator_optimizer.zero_grad()
self.mapping_network_optimizer.zero_grad()
Accumulate gradients for gradient_accumulate_steps
for i in range(self.gradient_accumulate_steps):
Sample images from generator
generated_images, w = self.generate_images(self.batch_size)
Discriminator classification for generated images
fake_output = self.discriminator(generated_images)
Get generator loss
gen_loss = self.generator_loss(fake_output)
Add path length penalty
if idx > self.lazy_path_penalty_after and (idx + 1) % self.lazy_path_penalty_interval == 0:
Calculate path length penalty
plp = self.path_length_penalty(w, generated_images)
Ignore if nan
if not torch.isnan(plp):
    tracker.add('loss.plp', plp)
    gen_loss = gen_loss + plp
Calculate gradients
gen_loss.backward()
Log generator loss
tracker.add('loss.generator', gen_loss)

if (idx + 1) % self.log_generated_interval == 0:
Log generator and mapping network parameters occasionally
tracker.add('generator', self.generator)
tracker.add('mapping_network', self.mapping_network)
Clip gradients for stabilization
torch.nn.utils.clip_grad_norm_(self.generator.parameters(), max_norm=1.0)
torch.nn.utils.clip_grad_norm_(self.mapping_network.parameters(), max_norm=1.0)
Take optimizer step
self.generator_optimizer.step()
self.mapping_network_optimizer.step()
Log generated images
if (idx + 1) % self.log_generated_interval == 0:
    tracker.add('generated', torch.cat([generated_images[:6], real_images[:3]], dim=0))
Save model checkpoints
if (idx + 1) % self.save_checkpoint_interval == 0:
    experiment.save_checkpoint()
Flush tracker
tracker.save()
def train(self):
Loop for training_steps
for i in monit.loop(self.training_steps):
Take a training step
self.step(i)

if (i + 1) % self.log_generated_interval == 0:
    tracker.new_line()
def main():
Create an experiment
experiment.create(name='stylegan2')
Create configurations object
configs = Configs()
Set configurations and override some
experiment.configs(configs, {
    'device.cuda_device': 0,
    'image_size': 64,
    'log_generated_interval': 200
})
Initialize
configs.init()
Set models for saving and loading
experiment.add_pytorch_models(mapping_network=configs.mapping_network,
                              generator=configs.generator,
                              discriminator=configs.discriminator)
Start the experiment
with experiment.start():
Run the training loop
configs.train()


if __name__ == '__main__':
    main()