[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/gan/cycle_gan/__init__.py)
This is a PyTorch implementation/tutorial of the paper [Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks](https://arxiv.org/abs/1703.10593).
I've taken pieces of code from [eriklindernoren/PyTorch-GAN](https://github.com/eriklindernoren/PyTorch-GAN). It is a very good resource if you want to check out other GAN variations too.
Cycle GAN does image-to-image translation. It trains a model to translate an image from a given distribution to another, say, from images of class A to images of class B. A distribution could be, for instance, images of a certain style or of a certain kind of scene. The models do not need paired images between A and B: just a set of images of each class is enough. This works very well for changing between image styles, lighting changes, pattern changes, etc. For example, changing summer to winter, painting style to photos, and horses to zebras.
Cycle GAN trains two generator models and two discriminator models. One generator translates images from A to B and the other from B to A. The discriminators test whether the generated images look real.
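To make the cycle idea concrete, here is a minimal, hypothetical sketch of one direction of the cycle. The `gen_ab` and `gen_ba` stand-ins below are placeholder 1×1 convolutions, not the actual generators defined later in this file:

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the two generators defined later in this file
gen_ab = nn.Conv2d(3, 3, kernel_size=1)  # translates A -> B
gen_ba = nn.Conv2d(3, 3, kernel_size=1)  # translates B -> A
l1 = nn.L1Loss()

x = torch.randn(1, 3, 64, 64)     # an (unpaired) image from class A
fake_b = gen_ab(x)                # translate it to class B
reconstructed_a = gen_ba(fake_b)  # translate it back to class A

# Cycle consistency: going A -> B -> A should recover the original image,
# which is what lets the model train without paired examples
cycle_loss = l1(reconstructed_a, x)
```

The same is done in the other direction ($B \rightarrow A \rightarrow B$), and the discriminators provide the adversarial signal that keeps the translated images realistic.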
This file contains the model code as well as the training code. We also have a Google Colab notebook.
```python
import itertools
import random
import zipfile
from typing import Tuple

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.utils import make_grid

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml.utils.download import download_file
from labml.utils.pytorch import get_modules
from labml_nn.helpers.device import DeviceConfigs
```
The generator is a residual network.
```python
class GeneratorResNet(nn.Module):
    def __init__(self, input_channels: int, n_residual_blocks: int):
        super().__init__()
```
This first block runs a 7×7 convolution and maps the image to a feature map. The output feature map has the same height and width because we have a padding of 3. Reflection padding is used because it gives better image quality at edges.
`inplace=True` in `ReLU` saves a little bit of memory.
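As a quick check of the shape-preservation claim (a sketch using the same convolution configuration as the block below):

```python
import torch
import torch.nn as nn

# A 7x7 convolution with reflection padding of 3 keeps height and width unchanged
conv = nn.Conv2d(3, 64, kernel_size=7, padding=3, padding_mode='reflect')
x = torch.randn(1, 3, 256, 256)
assert conv(x).shape == (1, 64, 256, 256)
```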
```python
        out_features = 64
        layers = [
            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features
```
We down-sample with two 3×3 convolutions with a stride of 2.
```python
        for _ in range(2):
            out_features *= 2
            layers += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
```
We take this through `n_residual_blocks` residual blocks. This module is defined below.
```python
        for _ in range(n_residual_blocks):
            layers += [ResidualBlock(out_features)]
```
Then the resulting feature map is up-sampled to match the original image height and width.
```python
        for _ in range(2):
            out_features //= 2
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
```
Finally we map the feature map to an RGB image
```python
        layers += [nn.Conv2d(out_features, input_channels, kernel_size=7, padding=3, padding_mode='reflect'),
                   nn.Tanh()]
```
Create a sequential module with the layers
```python
        self.layers = nn.Sequential(*layers)
```
Initialize weights to $\mathcal{N}(0, 0.02)$
```python
        self.apply(weights_init_normal)
```
```python
    def forward(self, x):
        return self.layers(x)
```
This is the residual block, with two convolution layers.
```python
class ResidualBlock(nn.Module):
    def __init__(self, in_features: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor):
        return x + self.block(x)
```
This is the discriminator.
```python
class Discriminator(nn.Module):
    def __init__(self, input_shape: Tuple[int, int, int]):
        super().__init__()
        channels, height, width = input_shape
```
The output of the discriminator is a map of probabilities, where each value says whether the corresponding region of the image is real or generated.
```python
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        self.layers = nn.Sequential(
            # Each of these blocks will shrink the height and width by a factor of 2
            DiscriminatorBlock(channels, 64, normalize=False),
            DiscriminatorBlock(64, 128),
            DiscriminatorBlock(128, 256),
            DiscriminatorBlock(256, 512),
            # Zero-pad on the top and left to keep the output height and width unchanged with the 4x4 kernel
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1)
        )
```
Initialize weights to $\mathcal{N}(0, 0.02)$
```python
        self.apply(weights_init_normal)
```
```python
    def forward(self, img):
        return self.layers(img)
```
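As a sanity check of the patch output shape (a sketch, assuming a 256×256 RGB input): four stride-2 blocks shrink each spatial dimension by a factor of $2^4$, so the discriminator scores a 16×16 grid of patches rather than the whole image.

```python
import torch

# Uses the Discriminator class defined above
disc = Discriminator((3, 256, 256))
img = torch.randn(1, 3, 256, 256)

# 256 // 2**4 = 16, so each of the 16x16 outputs scores one region of the input
assert disc(img).shape == (1, 1, 16, 16)
assert disc.output_shape == (1, 16, 16)
```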
This is the discriminator block module. It does a convolution, an optional normalization, and a leaky ReLU.
It shrinks the height and width of the input feature map by half.
```python
class DiscriminatorBlock(nn.Module):
    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
        super().__init__()
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor):
        return self.layers(x)
```
Initialize convolution layer weights to $\mathcal{N}(0, 0.02)$
```python
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
```
Load an image and change it to RGB if it is grey-scale.
```python
def load_image(path: str):
    image = Image.open(path)
    if image.mode != 'RGB':
        # `Image.paste` returns `None`, so paste onto a new RGB image and keep that
        rgb_image = Image.new("RGB", image.size)
        rgb_image.paste(image)
        image = rgb_image

    return image
```
```python
class ImageDataset(Dataset):
    @staticmethod
    def download(dataset_name: str):
```
URL
```python
        url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'
```
Download folder
```python
        root = lab.get_data_path() / 'cycle_gan'
        if not root.exists():
            root.mkdir(parents=True)
```
Download destination
```python
        archive = root / f'{dataset_name}.zip'
```
Download file (generally ~100MB)
```python
        download_file(url, archive)
```
Extract the archive
```python
        with zipfile.ZipFile(archive, 'r') as f:
            f.extractall(root)
```
* `dataset_name` is the name of the dataset
* `transforms_` is the set of image transforms
* `mode` is either `train` or `test`

```python
    def __init__(self, dataset_name: str, transforms_, mode: str):
```
Dataset path
```python
        root = lab.get_data_path() / 'cycle_gan' / dataset_name
```
Download if missing
```python
        if not root.exists():
            self.download(dataset_name)
```
Image transforms
```python
        self.transform = transforms.Compose(transforms_)
```
Get image paths
```python
        path_a = root / f'{mode}A'
        path_b = root / f'{mode}B'
        self.files_a = sorted(str(f) for f in path_a.iterdir())
        self.files_b = sorted(str(f) for f in path_b.iterdir())
```
```python
    def __getitem__(self, index):
```
Return a pair of images. These pairs get batched together, but they do not act as pairs during training, so it does not matter that the same images keep getting paired together.
```python
        return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
                "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}
```
```python
    def __len__(self):
```
Number of images in the dataset
```python
        return max(len(self.files_a), len(self.files_b))
```
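To see how the modulo indexing pairs images when the two sets have different sizes, here is a tiny, hypothetical example with plain Python lists:

```python
files_a = ['a0.jpg', 'a1.jpg', 'a2.jpg']  # hypothetical file lists of
files_b = ['b0.jpg', 'b1.jpg']            # unequal sizes
length = max(len(files_a), len(files_b))  # the dataset length is 3

# The shorter list wraps around, so index 2 pairs 'a2.jpg' with 'b0.jpg'
pairs = [(files_a[i % len(files_a)], files_b[i % len(files_b)]) for i in range(length)]
assert pairs[2] == ('a2.jpg', 'b0.jpg')
```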
The replay buffer is used to train the discriminator. Generated images are added to the replay buffer and sampled from it.

The replay buffer returns the newly added image with a probability of 0.5. Otherwise, it returns an older generated image and replaces that older image with the newly added one.

This is done to reduce model oscillation.
```python
class ReplayBuffer:
    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []
```
Add/retrieve an image
```python
    def push_and_pop(self, data: torch.Tensor):
        data = data.detach()
        res = []
        for element in data:
            if len(self.data) < self.max_size:
                self.data.append(element)
                res.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    res.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    res.append(element)
        return torch.stack(res)
```
```python
class Configs(BaseConfigs):
```
`DeviceConfigs` will pick a GPU if available
```python
    device: torch.device = DeviceConfigs()
```
Hyper-parameters
```python
    epochs: int = 200
    dataset_name: str = 'monet2photo'
    batch_size: int = 1

    data_loader_workers = 8

    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    decay_start = 100
```
The paper suggests using a least-squares loss instead of the negative log-likelihood, as it is found to be more stable.
```python
    gan_loss = torch.nn.MSELoss()
```
L1 loss is used for cycle loss and identity loss
```python
    cycle_loss = torch.nn.L1Loss()
    identity_loss = torch.nn.L1Loss()
```
Image dimensions
```python
    img_height = 256
    img_width = 256
    img_channels = 3
```
Number of residual blocks in the generator
```python
    n_residual_blocks = 9
```
Loss coefficients
```python
    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.0

    sample_interval = 500
```
Models
```python
    generator_xy: GeneratorResNet
    generator_yx: GeneratorResNet
    discriminator_x: Discriminator
    discriminator_y: Discriminator
```
Optimizers
```python
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
```
Learning rate schedules
```python
    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
```
Data loaders
```python
    dataloader: DataLoader
    valid_dataloader: DataLoader
```
Generate samples from the test set and display them
```python
    def sample_images(self, n: int):
        batch = next(iter(self.valid_dataloader))
        self.generator_xy.eval()
        self.generator_yx.eval()
        with torch.no_grad():
            # `nn.Module` has no `device` attribute, so use the configured device
            data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)
            gen_y = self.generator_xy(data_x)
            gen_x = self.generator_yx(data_y)
```
Arrange images along x-axis
```python
            data_x = make_grid(data_x, nrow=5, normalize=True)
            data_y = make_grid(data_y, nrow=5, normalize=True)
            gen_x = make_grid(gen_x, nrow=5, normalize=True)
            gen_y = make_grid(gen_y, nrow=5, normalize=True)
```
Arrange images along y-axis
```python
            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
```
Show samples
```python
        plot_image(image_grid)
```
```python
    def initialize(self):
```
```python
        input_shape = (self.img_channels, self.img_height, self.img_width)
```
Create the models
```python
        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.discriminator_x = Discriminator(input_shape).to(self.device)
        self.discriminator_y = Discriminator(input_shape).to(self.device)
```
Create the optimizers
```python
        self.generator_optimizer = torch.optim.Adam(
            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        self.discriminator_optimizer = torch.optim.Adam(
            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
```
Create the learning rate schedules. The learning rate stays flat until `decay_start` epochs, and then decays linearly to 0 by the end of training.
```python
        decay_epochs = self.epochs - self.decay_start
        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
```
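As a worked example with the defaults above (`epochs = 200`, `decay_start = 100`), the multiplier applied to the base learning rate at epoch $e$ is $1 - \max(0, e - 100) / 100$:

```python
epochs, decay_start = 200, 100
decay_epochs = epochs - decay_start

def factor(e: int) -> float:
    return 1.0 - max(0, e - decay_start) / decay_epochs

assert factor(0) == 1.0    # flat at the start
assert factor(100) == 1.0  # decay begins after `decay_start` epochs
assert factor(150) == 0.5  # halfway through the decay
assert factor(200) == 0.0  # reaches zero at the end of training
```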
Image transformations
```python
        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
```
Training data loader
```python
        self.dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )
```
Validation data loader
```python
        self.valid_dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, "test"),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )
```
We aim to solve:

$$G^*, F^* = \arg \min_{G, F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$

where $G$ translates images from $X \rightarrow Y$, $F$ translates images from $Y \rightarrow X$, $D_X$ tests if images are from $X$ space, $D_Y$ tests if images are from $Y$ space, and

$$\begin{aligned}
\mathcal{L}(G, F, D_X, D_Y)
    &= \mathcal{L}_{GAN}(G, D_Y, X, Y)
     + \mathcal{L}_{GAN}(F, D_X, Y, X)
     + \lambda_1 \mathcal{L}_{cyc}(G, F)
     + \lambda_2 \mathcal{L}_{identity}(G, F) \\
\mathcal{L}_{GAN}(G, F, D_Y, X, Y)
    &= \mathbb{E}_{y \sim p_{data}(y)}\big[\log D_Y(y)\big]
     + \mathbb{E}_{x \sim p_{data}(x)}\big[\log(1 - D_Y(G(x)))\big]
     + \mathbb{E}_{x \sim p_{data}(x)}\big[\log D_X(x)\big]
     + \mathbb{E}_{y \sim p_{data}(y)}\big[\log(1 - D_X(F(y)))\big] \\
\mathcal{L}_{cyc}(G, F)
    &= \mathbb{E}_{x \sim p_{data}(x)}\Big[\lVert F(G(x)) - x \rVert_1\Big]
     + \mathbb{E}_{y \sim p_{data}(y)}\Big[\lVert G(F(y)) - y \rVert_1\Big] \\
\mathcal{L}_{identity}(G, F)
    &= \mathbb{E}_{x \sim p_{data}(x)}\Big[\lVert F(x) - x \rVert_1\Big]
     + \mathbb{E}_{y \sim p_{data}(y)}\Big[\lVert G(y) - y \rVert_1\Big]
\end{aligned}$$
$\mathcal{L}_{GAN}$ is the generative adversarial loss from the original GAN paper.
$\mathcal{L}_{cyc}$ is the cyclic loss, where we try to get $F(G(x))$ to be similar to $x$, and $G(F(y))$ to be similar to $y$. Basically, if the two generators (transformations) are applied in series, we should get back the original image. This is the main contribution of this paper. It trains the generators to generate an image of the other distribution that is similar to the original image. Without this loss, $G(x)$ could generate anything from the distribution of $Y$. Now it needs to generate something from the distribution of $Y$ that still has the properties of $x$, so that $F(G(x))$ can re-generate something like $x$.
$\mathcal{L}_{identity}$ is the identity loss. This was used to encourage the mapping to preserve the color composition between the input and the output.
To solve $G^*, F^*$, the discriminators $D_X$ and $D_Y$ should ascend on the gradient,
$$\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m \Big[\log D_Y\big(y^{(i)}\big) + \log\Big(1 - D_Y\big(G(x^{(i)})\big)\Big) + \log D_X\big(x^{(i)}\big) + \log\Big(1 - D_X\big(F(y^{(i)})\big)\Big)\Big]$$
That is, they descend on the negative log-likelihood loss.
In order to stabilize the training, the negative log-likelihood objective was replaced by a least-squares loss: the least-squared error of the discriminator, labelling real images with 1 and generated images with 0. So we want to descend on the gradient,
$$\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^m \Big[\big(D_Y(y^{(i)}) - 1\big)^2 + D_Y\big(G(x^{(i)})\big)^2 + \big(D_X(x^{(i)}) - 1\big)^2 + D_X\big(F(y^{(i)})\big)^2\Big]$$
We use least-squares for the generators as well. The generators should descend on the gradient,
$$\nabla_{\theta_{F, G}} \frac{1}{m} \sum_{i=1}^m \Big[\big(D_Y(G(x^{(i)})) - 1\big)^2 + \big(D_X(F(y^{(i)})) - 1\big)^2 + \mathcal{L}_{cyc}(G, F) + \mathcal{L}_{identity}(G, F)\Big]$$
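The `gan_loss` configured above is `torch.nn.MSELoss`, which computes exactly these squared-error terms (averaged) when the target is a tensor of ones for real images or zeros for generated ones. A minimal check with hypothetical discriminator outputs:

```python
import torch

d_out = torch.tensor([0.3, 0.9, 0.7])  # hypothetical discriminator outputs
true_labels = torch.ones_like(d_out)

mse = torch.nn.MSELoss()
# MSELoss against ones is the mean of (D(x) - 1)^2, the least-squares GAN term
assert torch.isclose(mse(d_out, true_labels), ((d_out - 1) ** 2).mean())
```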
We use `generator_xy` for $G$ and `generator_yx` for $F$. We use `discriminator_x` for $D_X$ and `discriminator_y` for $D_Y$.
```python
    def run(self):
```
Replay buffers to keep generated samples
```python
        gen_x_buffer = ReplayBuffer()
        gen_y_buffer = ReplayBuffer()
```
Loop through epochs
```python
        for epoch in monit.loop(self.epochs):
```
Loop through the dataset
```python
            for i, batch in monit.enum('Train', self.dataloader):
```
Move images to the device
```python
                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)
```
true labels equal to 1
```python
                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
                                         device=self.device, requires_grad=False)
```
false labels equal to 0
```python
                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
                                           device=self.device, requires_grad=False)
```
Train the generators. This returns the generated images.
```python
                gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)
```
Train discriminators
```python
                self.optimize_discriminator(data_x, data_y,
                                            gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
                                            true_labels, false_labels)
```
Save training statistics and increment the global step counter
```python
                tracker.save()
                tracker.add_global_step(max(len(data_x), len(data_y)))
```
Save images at intervals
```python
                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:
```
Sample images
```python
                    self.sample_images(batches_done)
```
Update learning rates
```python
            self.generator_lr_scheduler.step()
            self.discriminator_lr_scheduler.step()
```
New line
```python
            tracker.new_line()
```
```python
    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
```
Change to training mode
```python
        self.generator_xy.train()
        self.generator_yx.train()
```
Identity loss: $\lVert F(x^{(i)}) - x^{(i)} \rVert_1$ and $\lVert G(y^{(i)}) - y^{(i)} \rVert_1$
```python
        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
                         self.identity_loss(self.generator_xy(data_y), data_y))
```
Generate images $G(x)$ and $F(y)$
```python
        gen_y = self.generator_xy(data_x)
        gen_x = self.generator_yx(data_y)
```
GAN loss: $\big(D_Y(G(x^{(i)})) - 1\big)^2 + \big(D_X(F(y^{(i)})) - 1\big)^2$
```python
        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                    self.gan_loss(self.discriminator_x(gen_x), true_labels))
```
Cycle loss: $\lVert F(G(x^{(i)})) - x^{(i)} \rVert_1 + \lVert G(F(y^{(i)})) - y^{(i)} \rVert_1$
```python
        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                      self.cycle_loss(self.generator_xy(gen_x), data_y))
```
Total loss
```python
        loss_generator = (loss_gan +
                          self.cyclic_loss_coefficient * loss_cycle +
                          self.identity_loss_coefficient * loss_identity)
```
Take a step in the optimizer
```python
        self.generator_optimizer.zero_grad()
        loss_generator.backward()
        self.generator_optimizer.step()
```
Log losses
```python
        tracker.add({'loss.generator': loss_generator,
                     'loss.generator.cycle': loss_cycle,
                     'loss.generator.gan': loss_gan,
                     'loss.generator.identity': loss_identity})
```
Return generated images
```python
        return gen_x, gen_y
```
```python
    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
                               gen_x: torch.Tensor, gen_y: torch.Tensor,
                               true_labels: torch.Tensor, false_labels: torch.Tensor):
```
GAN Loss
$$\big(D_Y(y^{(i)}) - 1\big)^2 + D_Y\big(G(x^{(i)})\big)^2 + \big(D_X(x^{(i)}) - 1\big)^2 + D_X\big(F(y^{(i)})\big)^2$$
```python
        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
                              self.gan_loss(self.discriminator_y(gen_y), false_labels))
```
Take a step in the optimizer
```python
        self.discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        self.discriminator_optimizer.step()
```
Log losses
```python
        tracker.add({'loss.discriminator': loss_discriminator})
```
```python
def train():
```
Create configurations
```python
    conf = Configs()
```
Create an experiment
```python
    experiment.create(name='cycle_gan')
```
Calculate configurations. It will calculate `conf.run` and all other configs required by it.
```python
    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
    conf.initialize()
```
Register models for saving and loading. `get_modules` gives a dictionary of `nn.Module`s in `conf`. You can also specify a custom dictionary of models.
```python
    experiment.add_pytorch_models(get_modules(conf))
```
Start and watch the experiment
```python
    with experiment.start():
```
Run the training
```python
        conf.run()
```
```python
def plot_image(img: torch.Tensor):
    from matplotlib import pyplot as plt
```
Move tensor to CPU
```python
    img = img.cpu()
```
Get min and max values of the image for normalization
```python
    img_min, img_max = img.min(), img.max()
```
Scale image values to be 0...1
```python
    img = (img - img_min) / (img_max - img_min + 1e-5)
```
We have to change the order of dimensions to HWC.
```python
    img = img.permute(1, 2, 0)
```
Show Image
```python
    plt.imshow(img)
```
We don't need axes
```python
    plt.axis('off')
```
Display
```python
    plt.show()
```
```python
def evaluate():
```
Set the run UUID from the training run
```python
    trained_run_uuid = 'f73c1164184711eb9190b74249275441'
```
Create configs object
```python
    conf = Configs()
```
Create experiment
```python
    experiment.create(name='cycle_gan_inference')
```
Load hyper parameters set for training
```python
    conf_dict = experiment.load_configs(trained_run_uuid)
```
Calculate configurations. We specify the generators `'generator_xy'` and `'generator_yx'` so that it only loads those and their dependencies. Configs like `device` and `img_channels` will be calculated, since they are required by `generator_xy` and `generator_yx`.

If you want other parameters like `dataset_name`, you should specify them here. If you specify nothing, all the configurations will be calculated, including the data loaders. The calculation of configurations and their dependencies will happen when you call `experiment.start`.
```python
    experiment.configs(conf, conf_dict)
    conf.initialize()
```
Register models for saving and loading. `get_modules` gives a dictionary of `nn.Module`s in `conf`. You can also specify a custom dictionary of models.
```python
    experiment.add_pytorch_models(get_modules(conf))
```
Specify which run to load from. Loading will actually happen when you call `experiment.start`.
```python
    experiment.load(trained_run_uuid)
```
Start the experiment
```python
    with experiment.start():
```
Image transformations
```python
        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
```
Load your own data. Here we use images from the dataset. I was trying it with Yosemite photos; they look awesome. You can use `conf.dataset_name` if you specified `dataset_name` as something you wanted to be calculated in the call to `experiment.configs`.
```python
        dataset = ImageDataset(conf.dataset_name, transforms_, 'train')
```
Get an image from dataset
```python
        x_image = dataset[10]['x']
```
Display the image
```python
        plot_image(x_image)
```
Evaluation mode
```python
        conf.generator_xy.eval()
        conf.generator_yx.eval()
```
We don't need gradients
```python
        with torch.no_grad():
```
Add batch dimension and move to the device we use
```python
            data = x_image.unsqueeze(0).to(conf.device)
            generated_y = conf.generator_xy(data)
```
Display the generated image.
```python
        plot_image(generated_y[0].cpu())


if __name__ == '__main__':
    train()
    # Run `evaluate()` instead to generate images with a trained model
    # evaluate()
```