Cycle GAN

[View code on GitHub](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/gan/cycle_gan/__init__.py)

#


This is a PyTorch implementation/tutorial of the paper Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.

I've taken pieces of code from eriklindernoren/PyTorch-GAN. It is a very good resource if you want to check out other GAN variations too.

Cycle GAN does image-to-image translation. It trains a model to translate an image from one distribution to another, say, from images of class A to images of class B. Images of a certain distribution could be images of a certain style, or of a certain kind of scene. The models do not need paired images between A and B; a set of images of each class is enough. This works very well for changing image styles, lighting, patterns, etc. For example, changing summer to winter, paintings to photos, and horses to zebras.

Cycle GAN trains two generator models and two discriminator models. One generator translates images from A to B and the other from B to A. The discriminators test whether the generated images look real.

This file contains the model code as well as the training code. We also have a Google Colab notebook.

```python
import itertools
import random
import zipfile
from typing import Tuple

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import InterpolationMode
from torchvision.utils import make_grid

from labml import lab, tracker, experiment, monit
from labml.configs import BaseConfigs
from labml.utils.download import download_file
from labml.utils.pytorch import get_modules
from labml_nn.helpers.device import DeviceConfigs
```

#

The generator is a residual network.

```python
class GeneratorResNet(nn.Module):
```

#

```python
    def __init__(self, input_channels: int, n_residual_blocks: int):
        super().__init__()
```

#

This first block runs a 7×7 convolution and maps the image to a feature map. The output feature map has the same height and width because we have a padding of 3. Reflection padding is used because it gives better image quality at edges.

inplace=True in ReLU saves a little bit of memory.

```python
        out_features = 64
        layers = [
            nn.Conv2d(input_channels, out_features, kernel_size=7, padding=3, padding_mode='reflect'),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features
```
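The pure-Python sketch below compares 1-D reflection padding with zero padding to show why reflection padding avoids dark borders. The helper names `reflect_pad_1d` and `zero_pad_1d` are hypothetical. The sketch assumes the same convention as PyTorch's `padding_mode='reflect'`, which mirrors around the edge value without repeating it.

```python
def reflect_pad_1d(row: list, pad: int) -> list:
    """Mirror `row` around its first and last elements without repeating them."""
    if pad >= len(row):
        raise ValueError("pad must be smaller than the row length")
    left = row[pad:0:-1]         # row[pad], ..., row[1]
    right = row[-2:-pad - 2:-1]  # row[-2], ..., row[-pad - 1]
    return left + row + right


def zero_pad_1d(row: list, pad: int) -> list:
    """Surround `row` with `pad` zeros on each side."""
    return [0] * pad + row + [0] * pad


row = [1, 2, 3, 4]
print(reflect_pad_1d(row, 2))  # [3, 2, 1, 2, 3, 4, 3, 2]
print(zero_pad_1d(row, 2))     # [0, 0, 1, 2, 3, 4, 0, 0]
```

With reflection padding, the padded values come from the image itself. A 7×7 kernel at the border therefore sees plausible content instead of a black frame.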

#

We down-sample with two 3×3 convolutions with a stride of 2. Each one halves the height and width and doubles the number of channels.

```python
        for _ in range(2):
            out_features *= 2
            layers += [
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
```

#

We then pass the feature map through n_residual_blocks residual blocks. The ResidualBlock module is defined below.

```python
        for _ in range(n_residual_blocks):
            layers += [ResidualBlock(out_features)]
```

#

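Then the resulting feature map is up-sampled to match the original image height and width. In each of the two steps, nn.Upsample doubles the height and width, and a 3×3 convolution halves the number of channels.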

```python
        for _ in range(2):
            out_features //= 2
            layers += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, kernel_size=3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features
```

#

Finally, we map the feature map to an image with input_channels channels (RGB here). Tanh squashes the values to [-1, 1].

```python
        layers += [nn.Conv2d(out_features, input_channels, 7, padding=3, padding_mode='reflect'), nn.Tanh()]
```
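As a sanity check on the shapes, here is a small pure-Python trace of the channel count and spatial size through GeneratorResNet for a 256×256 RGB input. It uses the standard convolution output-size formula. The helpers `conv2d_out` and `trace_generator` are hypothetical; they only mirror the layer hyper-parameters above.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height (or width) of a 2-D convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def trace_generator(size: int = 256, channels: int = 3, n_residual_blocks: int = 9) -> list:
    """Return (stage, channels, height/width) after each stage of GeneratorResNet."""
    stages = []
    features = 64
    # 7x7 convolution with padding 3 keeps the size
    size = conv2d_out(size, kernel=7, stride=1, padding=3)
    stages.append(("stem", features, size))
    # Two 3x3 stride-2 convolutions: halve the size, double the channels
    for _ in range(2):
        features *= 2
        size = conv2d_out(size, kernel=3, stride=2, padding=1)
        stages.append(("down", features, size))
    # Residual blocks: two 3x3 convolutions with padding 1 each, shape unchanged
    for _ in range(n_residual_blocks):
        size = conv2d_out(conv2d_out(size, kernel=3, padding=1), kernel=3, padding=1)
    stages.append(("residual", features, size))
    # nn.Upsample(scale_factor=2) doubles the size, then a 3x3 stride-1 convolution halves the channels
    for _ in range(2):
        features //= 2
        size = conv2d_out(size * 2, kernel=3, stride=1, padding=1)
        stages.append(("up", features, size))
    # Final 7x7 convolution back to the image channels
    size = conv2d_out(size, kernel=7, stride=1, padding=3)
    stages.append(("output", channels, size))
    return stages


for stage, c, s in trace_generator():
    print(f"{stage:<8} {c:>3} x {s} x {s}")
```

The encoder reduces 256×256 to 64×64 with 256 channels, and the residual blocks keep that shape. The decoder then restores 256×256, so the output has the same shape as the input.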

#

Create a sequential module with the layers

```python
        self.layers = nn.Sequential(*layers)
```

#

Initialize weights to N(0, 0.02)

```python
        self.apply(weights_init_normal)
```

#

```python
    def forward(self, x):
        return self.layers(x)
```

#

This is the residual block, with two convolution layers.

```python
class ResidualBlock(nn.Module):
```

#

```python
    def __init__(self, in_features: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_features, in_features, kernel_size=3, padding=1, padding_mode='reflect'),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
        )
```

#

```python
    def forward(self, x: torch.Tensor):
        return x + self.block(x)
```

#

This is the discriminator.

```python
class Discriminator(nn.Module):
```

#

```python
    def __init__(self, input_shape: Tuple[int, int, int]):
        super().__init__()
        channels, height, width = input_shape
```

#

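The output of the discriminator is a map of scores, one for each region of the image, indicating whether that region looks real or generated. There is no sigmoid at the end and training uses a least-squares loss, so these scores are not strictly probabilities. They are pushed towards 1 for real images and towards 0 for generated ones.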

```python
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        self.layers = nn.Sequential(
```

#

Each of these blocks will shrink the height and width by a factor of 2

```python
            DiscriminatorBlock(channels, 64, normalize=False),
            DiscriminatorBlock(64, 128),
            DiscriminatorBlock(128, 256),
            DiscriminatorBlock(256, 512),
```

#

Zero pad on top and left to keep the output height and width same with the 4×4 kernel

```python
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, kernel_size=4, padding=1)
        )
```
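This pure-Python sketch checks that the layer stack produces a map of size height // 2**4, matching self.output_shape. For a 256×256 input it gives a 16×16 map, so the training labels have shape (batch, 1, 16, 16). The helpers `conv2d_out` and `discriminator_map_size` are hypothetical.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height (or width) of a 2-D convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def discriminator_map_size(size: int) -> int:
    """Height (or width) of the Discriminator's output map for an input of that height (or width)."""
    # Four DiscriminatorBlocks: a 4x4 convolution with stride 2 and padding 1 halves an even size
    for _ in range(4):
        size = conv2d_out(size, kernel=4, stride=2, padding=1)
    # ZeroPad2d((1, 0, 1, 0)) adds one row on top and one column on the left
    size += 1
    # Final 4x4 convolution with stride 1 and padding 1 removes one row and one column again
    return conv2d_out(size, kernel=4, stride=1, padding=1)


print(discriminator_map_size(256))
```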

#

Initialize weights to N(0, 0.02)

```python
        self.apply(weights_init_normal)
```

#

```python
    def forward(self, img):
        return self.layers(img)
```

#

This is the discriminator block module. It does a convolution, an optional normalization, and a leaky ReLU.

It shrinks the height and width of the input feature map by half.

```python
class DiscriminatorBlock(nn.Module):
```

#

```python
    def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
        super().__init__()
        layers = [nn.Conv2d(in_filters, out_filters, kernel_size=4, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_filters))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        self.layers = nn.Sequential(*layers)
```

#

```python
    def forward(self, x: torch.Tensor):
        return self.layers(x)
```

#

Initialize convolution layer weights to N(0, 0.02), that is, a normal distribution with mean 0 and standard deviation 0.02.

```python
def weights_init_normal(m):
```

#

```python
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
```
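`self.apply(weights_init_normal)` calls the function on every sub-module, including containers such as `nn.Sequential`. The class-name test then decides which modules get re-initialized. The pure-Python check below applies the same test to the class names of the modules used in this file. The helper `matches_conv` is hypothetical. Only `Conv2d` matches, so `InstanceNorm2d` layers keep their defaults; with the default `affine=False` they have no weights anyway.

```python
def matches_conv(classname: str) -> bool:
    # The same test weights_init_normal applies to m.__class__.__name__
    return classname.find("Conv") != -1


layer_names = ["Conv2d", "InstanceNorm2d", "ReLU", "LeakyReLU", "Upsample", "Tanh",
               "ZeroPad2d", "Sequential", "ResidualBlock", "DiscriminatorBlock"]
initialized = [name for name in layer_names if matches_conv(name)]
print(initialized)
```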

#

Load an image and convert it to RGB if it is in another mode, such as grey-scale.

```python
def load_image(path: str):
```

#

```python
    image = Image.open(path)
    if image.mode != 'RGB':
        # Convert grey-scale (or any other mode) to a 3-channel RGB image
        image = image.convert('RGB')

    return image
```

#

Dataset to load images

```python
class ImageDataset(Dataset):
```

#

Download dataset and extract data

```python
    @staticmethod
    def download(dataset_name: str):
```

#

URL

```python
        url = f'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/{dataset_name}.zip'
```

#

Download folder

```python
        root = lab.get_data_path() / 'cycle_gan'
        if not root.exists():
            root.mkdir(parents=True)
```

#

Download destination

```python
        archive = root / f'{dataset_name}.zip'
```

#

Download file (generally ~100MB)

```python
        download_file(url, archive)
```

#

Extract the archive

```python
        with zipfile.ZipFile(archive, 'r') as f:
            f.extractall(root)
```

#

Initialize the dataset

  • dataset_name is the name of the dataset
  • transforms_ is the set of image transforms
  • mode is either train or test
```python
    def __init__(self, dataset_name: str, transforms_, mode: str):
```

#

Dataset path

```python
        root = lab.get_data_path() / 'cycle_gan' / dataset_name
```

#

Download if missing

```python
        if not root.exists():
            self.download(dataset_name)
```

#

Image transforms

```python
        self.transform = transforms.Compose(transforms_)
```

#

Get image paths

```python
        path_a = root / f'{mode}A'
        path_b = root / f'{mode}B'
        self.files_a = sorted(str(f) for f in path_a.iterdir())
        self.files_b = sorted(str(f) for f in path_b.iterdir())
```

#

```python
    def __getitem__(self, index):
```

#

Return a pair of images. The dataset is unpaired: these pairs get batched together, but they are not treated as pairs in training. So it is fine that a given index always returns the same pair.

```python
        return {"x": self.transform(load_image(self.files_a[index % len(self.files_a)])),
                "y": self.transform(load_image(self.files_b[index % len(self.files_b)]))}
```

#

```python
    def __len__(self):
```

#

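Number of images in the dataset: the size of the larger of the two image sets. In __getitem__, indices wrap around (modulo) for the smaller set.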

```python
        return max(len(self.files_a), len(self.files_b))
```
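To make the wrap-around indexing concrete, here is a small pure-Python sketch with two hypothetical file lists of different lengths. The helper `get_pair` is hypothetical; it mirrors `__getitem__` without loading any images.

```python
# Hypothetical file lists of different lengths
files_a = ["a0.jpg", "a1.jpg", "a2.jpg"]
files_b = ["b0.jpg", "b1.jpg", "b2.jpg", "b3.jpg", "b4.jpg"]

# Same rule as ImageDataset.__len__
length = max(len(files_a), len(files_b))


def get_pair(index: int) -> tuple:
    """Same wrap-around indexing as ImageDataset.__getitem__, without loading the images."""
    return files_a[index % len(files_a)], files_b[index % len(files_b)]


pairs = [get_pair(i) for i in range(length)]
print(pairs)
```

Every image of the larger set appears once per epoch, while images of the smaller set are repeated. With `shuffle=True` in the data loader, the order changes every epoch.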

#

Replay Buffer

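A replay buffer is used to train the discriminator. Generated images are added to the replay buffer and sampled from it.

Until the buffer holds max_size images, each new image is stored and returned as is. After that, the buffer returns the newly added image with a probability of 0.5. Otherwise, it returns a randomly chosen older generated image and replaces that image with the new one.

This is done to reduce model oscillation: the discriminator is trained on a history of generated images rather than only on the latest generator outputs.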

```python
class ReplayBuffer:
```

#

```python
    def __init__(self, max_size: int = 50):
        self.max_size = max_size
        self.data = []
```

#

Add/retrieve an image

```python
    def push_and_pop(self, data: torch.Tensor):
```

#

```python
        data = data.detach()
        res = []
        for element in data:
            if len(self.data) < self.max_size:
                self.data.append(element)
                res.append(element)
            else:
                if random.uniform(0, 1) > 0.5:
                    i = random.randint(0, self.max_size - 1)
                    res.append(self.data[i].clone())
                    self.data[i] = element
                else:
                    res.append(element)
        return torch.stack(res)
```
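The buffer's behaviour is easier to see without tensors. Below is a pure-Python sketch of the same push-and-pop logic on plain items, with a seeded random generator for reproducibility. The class name `ListReplayBuffer` and the `seed` parameter are hypothetical and not part of the implementation above.

```python
import random


class ListReplayBuffer:
    """Pure-Python version of ReplayBuffer.push_and_pop that stores plain items instead of tensors."""

    def __init__(self, max_size: int = 50, seed: int = 0):
        self.max_size = max_size
        self.data = []
        self.rng = random.Random(seed)  # seeded so runs are reproducible

    def push_and_pop(self, batch: list) -> list:
        res = []
        for element in batch:
            if len(self.data) < self.max_size:
                # Buffer not full yet: store the new item and return it unchanged
                self.data.append(element)
                res.append(element)
            elif self.rng.uniform(0, 1) > 0.5:
                # Return a random older item and put the new one in its place
                i = self.rng.randint(0, self.max_size - 1)
                res.append(self.data[i])
                self.data[i] = element
            else:
                # Return the new item; the buffer is unchanged
                res.append(element)
        return res


buffer = ListReplayBuffer(max_size=3)
first = buffer.push_and_pop([1, 2, 3])        # fills the buffer
later = buffer.push_and_pop([4, 5, 6, 7, 8])  # mix of new and older items
print(first, later, buffer.data)
```

Each new item either goes straight back out or is swapped for an older one. The returned items plus the buffer contents therefore always account for every item pushed since the buffer filled up.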

#

Configurations

```python
class Configs(BaseConfigs):
```

#

DeviceConfigs will pick a GPU if available

```python
    device: torch.device = DeviceConfigs()
```

#

Hyper-parameters

```python
    epochs: int = 200
    dataset_name: str = 'monet2photo'
    batch_size: int = 1

    data_loader_workers = 8

    learning_rate = 0.0002
    adam_betas = (0.5, 0.999)
    decay_start = 100
```

#

The paper suggests using a least-squares loss instead of negative log-likelihood, as it is found to be more stable.

```python
    gan_loss = torch.nn.MSELoss()
```

#

L1 loss is used for cycle loss and identity loss

```python
    cycle_loss = torch.nn.L1Loss()
    identity_loss = torch.nn.L1Loss()
```

#

Image dimensions

```python
    img_height = 256
    img_width = 256
    img_channels = 3
```

#

Number of residual blocks in the generator

```python
    n_residual_blocks = 9
```

#

Loss coefficients

```python
    cyclic_loss_coefficient = 10.0
    identity_loss_coefficient = 5.

    sample_interval = 500
```

#

Models

```python
    generator_xy: GeneratorResNet
    generator_yx: GeneratorResNet
    discriminator_x: Discriminator
    discriminator_y: Discriminator
```

#

Optimizers

```python
    generator_optimizer: torch.optim.Adam
    discriminator_optimizer: torch.optim.Adam
```

#

Learning rate schedules

```python
    generator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
    discriminator_lr_scheduler: torch.optim.lr_scheduler.LambdaLR
```

#

Data loaders

```python
    dataloader: DataLoader
    valid_dataloader: DataLoader
```

#

Generate samples from the test set and display them

```python
    def sample_images(self, n: int):
```

#

```python
        batch = next(iter(self.valid_dataloader))
        self.generator_xy.eval()
        self.generator_yx.eval()
        with torch.no_grad():
            data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)
            gen_y = self.generator_xy(data_x)
            gen_x = self.generator_yx(data_y)
```

#

Arrange images along x-axis

```python
            data_x = make_grid(data_x, nrow=5, normalize=True)
            data_y = make_grid(data_y, nrow=5, normalize=True)
            gen_x = make_grid(gen_x, nrow=5, normalize=True)
            gen_y = make_grid(gen_y, nrow=5, normalize=True)
```

#

Arrange images along y-axis

```python
            image_grid = torch.cat((data_x, gen_y, data_y, gen_x), 1)
```

#

Show samples

```python
            plot_image(image_grid)
```

#

Initialize models and data loaders

```python
    def initialize(self):
```

#

```python
        input_shape = (self.img_channels, self.img_height, self.img_width)
```

#

Create the models

```python
        self.generator_xy = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.generator_yx = GeneratorResNet(self.img_channels, self.n_residual_blocks).to(self.device)
        self.discriminator_x = Discriminator(input_shape).to(self.device)
        self.discriminator_y = Discriminator(input_shape).to(self.device)
```

#

Create the optimizers

```python
        self.generator_optimizer = torch.optim.Adam(
            itertools.chain(self.generator_xy.parameters(), self.generator_yx.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
        self.discriminator_optimizer = torch.optim.Adam(
            itertools.chain(self.discriminator_x.parameters(), self.discriminator_y.parameters()),
            lr=self.learning_rate, betas=self.adam_betas)
```

#

Create the learning rate schedules. The learning rate stays flat until decay_start epochs, and then linearly reduces to 0 by the end of training.

```python
        decay_epochs = self.epochs - self.decay_start
        self.generator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.generator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
        self.discriminator_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.discriminator_optimizer, lr_lambda=lambda e: 1.0 - max(0, e - self.decay_start) / decay_epochs)
```
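The schedule can be checked without PyTorch. The sketch below reproduces the `lr_lambda` arithmetic with the default `epochs = 200` and `decay_start = 100`. The function name `lr_factor` is hypothetical.

```python
def lr_factor(e: int, epochs: int = 200, decay_start: int = 100) -> float:
    """The multiplier computed by lr_lambda for epoch index e."""
    decay_epochs = epochs - decay_start
    return 1.0 - max(0, e - decay_start) / decay_epochs


learning_rate = 0.0002
for e in [0, 100, 150, 199, 200]:
    print(e, lr_factor(e), learning_rate * lr_factor(e))
```

`LambdaLR` multiplies the initial learning rate by `lr_lambda` evaluated at the number of `step()` calls so far. Here `step()` is called once per epoch, so epoch `e` trains with `learning_rate * lr_factor(e)`. The last epoch (`e = 199`) uses 1% of the initial rate, and the factor reaches 0 only after the final step.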

#

Image transformations

```python
        transforms_ = [
            transforms.Resize(int(self.img_height * 1.12), InterpolationMode.BICUBIC),
            transforms.RandomCrop((self.img_height, self.img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
```
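Here is the arithmetic behind these transforms for the default 256×256 images, as a pure-Python sketch. `Resize` with a single integer scales the shorter side to that size before the random 256×256 crop. `Normalize` with mean 0.5 and standard deviation 0.5 computes `(v - 0.5) / 0.5` per channel. The helper `normalize` is hypothetical.

```python
img_height = 256
# Shorter side after Resize: about 12% larger than the crop, so RandomCrop has room to move
resize_to = int(img_height * 1.12)
print(resize_to)  # 286


def normalize(v: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Per-channel arithmetic of transforms.Normalize: (v - mean) / std."""
    return (v - mean) / std


# ToTensor gives values in [0, 1]; Normalize maps them to [-1, 1], the range of the generator's Tanh
print([normalize(v) for v in (0.0, 0.5, 1.0)])  # [-1.0, 0.0, 1.0]
```

Mapping inputs to [-1, 1] means real images and generated images (Tanh outputs) share the same value range when they are fed to the discriminators.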

#

Training data loader

```python
        self.dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, 'train'),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )
```

#

Validation data loader

```python
        self.valid_dataloader = DataLoader(
            ImageDataset(self.dataset_name, transforms_, "test"),
            batch_size=5,
            shuffle=True,
            num_workers=self.data_loader_workers,
        )
```

#

Training

We aim to solve:

$$G^{*}, F^{*} = \arg \min_{G, F} \max_{D_X, D_Y} \mathcal{L}(G, F, D_X, D_Y)$$

where $G$ translates images from $X \rightarrow Y$, $F$ translates images from $Y \rightarrow X$, $D_X$ tests if images are from $X$ space, $D_Y$ tests if images are from $Y$ space, and

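$$
\begin{aligned}
\mathcal{L}(G, F, D_X, D_Y) &= \mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X) + \lambda_1 \mathcal{L}_{cyc}(G, F) + \lambda_2 \mathcal{L}_{identity}(G, F) \\
\mathcal{L}_{GAN}(G, D_Y, X, Y) + \mathcal{L}_{GAN}(F, D_X, Y, X) &= \mathbb{E}_{y \sim p_{data}(y)}\big[\log D_Y(y)\big] + \mathbb{E}_{x \sim p_{data}(x)}\big[\log\big(1 - D_Y(G(x))\big)\big] \\
&\quad + \mathbb{E}_{x \sim p_{data}(x)}\big[\log D_X(x)\big] + \mathbb{E}_{y \sim p_{data}(y)}\big[\log\big(1 - D_X(F(y))\big)\big] \\
\mathcal{L}_{cyc}(G, F) &= \mathbb{E}_{x \sim p_{data}(x)}\big[\lVert F(G(x)) - x \rVert_1\big] + \mathbb{E}_{y \sim p_{data}(y)}\big[\lVert G(F(y)) - y \rVert_1\big] \\
\mathcal{L}_{identity}(G, F) &= \mathbb{E}_{x \sim p_{data}(x)}\big[\lVert F(x) - x \rVert_1\big] + \mathbb{E}_{y \sim p_{data}(y)}\big[\lVert G(y) - y \rVert_1\big]
\end{aligned}
$$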

$\mathcal{L}_{GAN}$ is the generative adversarial loss from the original GAN paper.

$\mathcal{L}_{cyc}$ is the cyclic loss, where we try to get $F(G(x))$ to be similar to $x$, and $G(F(y))$ to be similar to $y$. Basically, if the two generators (transformations) are applied in series, they should give back the original image. This is the main contribution of this paper. It trains the generators to generate an image of the other distribution that is similar to the original image. Without this loss, $G(x)$ could generate anything from the distribution of $Y$. Now it needs to generate something from the distribution of $Y$ that still has the properties of $x$, so that $F(G(x))$ can re-generate something like $x$.
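A toy illustration of the cycle loss follows, with short lists of numbers standing in for images. The mappings `g`, `f_inverse` and `f_lossy` are hypothetical scalar functions, not the networks above. The loss is zero when the second mapping undoes the first and positive when it does not.

```python
def l1_loss(a: list, b: list) -> float:
    """Mean absolute error, like torch.nn.L1Loss with its default 'mean' reduction."""
    return sum(abs(p - q) for p, q in zip(a, b)) / len(a)


def g(x: list) -> list:
    """Toy X -> Y mapping."""
    return [2 * v + 1 for v in x]


def f_inverse(y: list) -> list:
    """Toy Y -> X mapping that exactly undoes g."""
    return [(v - 1) / 2 for v in y]


def f_lossy(y: list) -> list:
    """Toy Y -> X mapping that does not undo g."""
    return [v / 2 for v in y]


x = [0.0, 1.0, 2.0]
print(l1_loss(f_inverse(g(x)), x))  # 0.0
print(l1_loss(f_lossy(g(x)), x))    # 0.5
```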

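$\mathcal{L}_{identity}$ is the identity loss. It was used to encourage the mapping to preserve the color composition between the input and the output: a generator given an image that already belongs to its target domain should return it unchanged, i.e. $G(y) \approx y$ and $F(x) \approx x$.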

To solve for $G^{*}, F^{*}$, the discriminators $D_X$ and $D_Y$ should ascend on the gradient,

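$$\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^{m} \Big[\log D_Y\big(y^{(i)}\big) + \log\Big(1 - D_Y\big(G(x^{(i)})\big)\Big) + \log D_X\big(x^{(i)}\big) + \log\Big(1 - D_X\big(F(y^{(i)})\big)\Big)\Big]$$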

That is, descend on the negative log-likelihood loss.

To stabilize the training, the negative log-likelihood objective was replaced by a least-squares loss: the squared error of the discriminator output, labelling real images with 1 and generated images with 0. So we want to descend on the gradient,

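$$\nabla_{\theta_{D_X, D_Y}} \frac{1}{m} \sum_{i=1}^{m} \Big[\big(D_Y(y^{(i)}) - 1\big)^2 + D_Y\big(G(x^{(i)})\big)^2 + \big(D_X(x^{(i)}) - 1\big)^2 + D_X\big(F(y^{(i)})\big)^2\Big]$$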

We use least squares for the generators as well. The generators should descend on the gradient,

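$$\nabla_{\theta_{F, G}} \frac{1}{m} \sum_{i=1}^{m} \Big[\big(D_Y(G(x^{(i)})) - 1\big)^2 + \big(D_X(F(y^{(i)})) - 1\big)^2 + \lambda_1 \mathcal{L}_{cyc}(G, F) + \lambda_2 \mathcal{L}_{identity}(G, F)\Big]$$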

We use generator_xy for $G$ and generator_yx for $F$. We use discriminator_x for $D_X$ and discriminator_y for $D_Y$.
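The pure-Python sketch below makes the least-squares objectives concrete. Single scalar discriminator outputs stand in for the per-region output maps; in the code, `nn.MSELoss` averages over the map. The function names are hypothetical. The defaults `lambda_1 = 10` and `lambda_2 = 5` are taken from `cyclic_loss_coefficient` and `identity_loss_coefficient` in the configurations.

```python
def discriminator_loss(d_real: float, d_fake: float) -> float:
    """Least-squares loss of one discriminator: real images labelled 1, generated images labelled 0."""
    return (d_real - 1) ** 2 + d_fake ** 2


def generator_loss(d_y_of_g_x: float, d_x_of_f_y: float, loss_cyc: float, loss_identity: float,
                   lambda_1: float = 10.0, lambda_2: float = 5.0) -> float:
    """Least-squares adversarial terms for G and F (target label 1) plus weighted cycle and identity losses."""
    return (d_y_of_g_x - 1) ** 2 + (d_x_of_f_y - 1) ** 2 + lambda_1 * loss_cyc + lambda_2 * loss_identity


# A perfect discriminator: real scored 1, generated scored 0
print(discriminator_loss(1.0, 0.0))  # 0.0
# A confused discriminator that scores everything 0.5
print(discriminator_loss(0.5, 0.5))  # 0.5
# Generators that fool both discriminators but still have small cycle and identity errors
print(generator_loss(1.0, 1.0, 0.1, 0.1))  # 1.5
```

Even when both discriminators are fully fooled, the weighted cycle and identity terms keep the generators anchored to their inputs.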

```python
    def run(self):
```

#

Replay buffers to keep generated samples

```python
        gen_x_buffer = ReplayBuffer()
        gen_y_buffer = ReplayBuffer()
```

#

Loop through epochs

```python
        for epoch in monit.loop(self.epochs):
```

#

Loop through the dataset

```python
            for i, batch in monit.enum('Train', self.dataloader):
```

#

Move images to the device

```python
                data_x, data_y = batch['x'].to(self.device), batch['y'].to(self.device)
```

#

True labels, all equal to 1

```python
                true_labels = torch.ones(data_x.size(0), *self.discriminator_x.output_shape,
                                         device=self.device, requires_grad=False)
```

#

False labels, all equal to 0

```python
                false_labels = torch.zeros(data_x.size(0), *self.discriminator_x.output_shape,
                                           device=self.device, requires_grad=False)
```

#

Train the generators. This returns the generated images.

```python
                gen_x, gen_y = self.optimize_generators(data_x, data_y, true_labels)
```

#

Train discriminators

```python
                self.optimize_discriminator(data_x, data_y,
                                            gen_x_buffer.push_and_pop(gen_x), gen_y_buffer.push_and_pop(gen_y),
                                            true_labels, false_labels)
```

#

Save training statistics and increment the global step counter

```python
                tracker.save()
                tracker.add_global_step(max(len(data_x), len(data_y)))
```

#

Save images at intervals

```python
                batches_done = epoch * len(self.dataloader) + i
                if batches_done % self.sample_interval == 0:
```

#

Sample images

```python
                    self.sample_images(batches_done)
```

#

Update learning rates

```python
            self.generator_lr_scheduler.step()
            self.discriminator_lr_scheduler.step()
```

#

New line

```python
            tracker.new_line()
```

#

Optimize the generators with the identity, GAN and cycle losses.

```python
    def optimize_generators(self, data_x: torch.Tensor, data_y: torch.Tensor, true_labels: torch.Tensor):
```

#

Change to training mode

```python
        self.generator_xy.train()
        self.generator_yx.train()
```

#

Identity loss $\lVert F(x^{(i)}) - x^{(i)} \rVert_1 + \lVert G(y^{(i)}) - y^{(i)} \rVert_1$

```python
        loss_identity = (self.identity_loss(self.generator_yx(data_x), data_x) +
                         self.identity_loss(self.generator_xy(data_y), data_y))
```

#

Generate images G(x) and F(y)

```python
        gen_y = self.generator_xy(data_x)
        gen_x = self.generator_yx(data_y)
```

#

GAN loss $\big(D_Y(G(x^{(i)})) - 1\big)^2 + \big(D_X(F(y^{(i)})) - 1\big)^2$

```python
        loss_gan = (self.gan_loss(self.discriminator_y(gen_y), true_labels) +
                    self.gan_loss(self.discriminator_x(gen_x), true_labels))
```

#

Cycle loss $\lVert F(G(x^{(i)})) - x^{(i)} \rVert_1 + \lVert G(F(y^{(i)})) - y^{(i)} \rVert_1$

```python
        loss_cycle = (self.cycle_loss(self.generator_yx(gen_y), data_x) +
                      self.cycle_loss(self.generator_xy(gen_x), data_y))
```

#

Total loss

```python
        loss_generator = (loss_gan +
                          self.cyclic_loss_coefficient * loss_cycle +
                          self.identity_loss_coefficient * loss_identity)
```

#

Take a step in the optimizer

```python
        self.generator_optimizer.zero_grad()
        loss_generator.backward()
        self.generator_optimizer.step()
```

#

Log losses

```python
        tracker.add({'loss.generator': loss_generator,
                     'loss.generator.cycle': loss_cycle,
                     'loss.generator.gan': loss_gan,
                     'loss.generator.identity': loss_identity})
```

#

Return generated images

```python
        return gen_x, gen_y
```

#

Optimize the discriminators with the GAN loss.

```python
    def optimize_discriminator(self, data_x: torch.Tensor, data_y: torch.Tensor,
                               gen_x: torch.Tensor, gen_y: torch.Tensor,
                               true_labels: torch.Tensor, false_labels: torch.Tensor):
```

#

GAN Loss

$$\big(D_Y(y^{(i)}) - 1\big)^2 + D_Y\big(G(x^{(i)})\big)^2 + \big(D_X(x^{(i)}) - 1\big)^2 + D_X\big(F(y^{(i)})\big)^2$$

```python
        loss_discriminator = (self.gan_loss(self.discriminator_x(data_x), true_labels) +
                              self.gan_loss(self.discriminator_x(gen_x), false_labels) +
                              self.gan_loss(self.discriminator_y(data_y), true_labels) +
                              self.gan_loss(self.discriminator_y(gen_y), false_labels))
```

#

Take a step in the optimizer

```python
        self.discriminator_optimizer.zero_grad()
        loss_discriminator.backward()
        self.discriminator_optimizer.step()
```

#

Log losses

```python
        tracker.add({'loss.discriminator': loss_discriminator})
```

#

Train Cycle GAN

```python
def train():
```

#

Create configurations

```python
    conf = Configs()
```

#

Create an experiment

```python
    experiment.create(name='cycle_gan')
```

#

Calculate configurations, overriding dataset_name, and then initialize the models, optimizers, learning rate schedules and data loaders.

```python
    experiment.configs(conf, {'dataset_name': 'summer2winter_yosemite'})
    conf.initialize()
```

#

Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.

```python
    experiment.add_pytorch_models(get_modules(conf))
```

#

Start and watch the experiment

```python
    with experiment.start():
```

#

Run the training

```python
        conf.run()
```

#

Plot an image with matplotlib

```python
def plot_image(img: torch.Tensor):
```

#

```python
    from matplotlib import pyplot as plt
```

#

Move tensor to CPU

```python
    img = img.cpu()
```

#

Get min and max values of the image for normalization

```python
    img_min, img_max = img.min(), img.max()
```

#

Scale image values to be 0...1

```python
    img = (img - img_min) / (img_max - img_min + 1e-5)
```

#

We have to change the order of dimensions from CHW to HWC, which is the layout plt.imshow expects.

```python
    img = img.permute(1, 2, 0)
```

#

Show Image

```python
    plt.imshow(img)
```

#

We don't need axes

```python
    plt.axis('off')
```

#

Display

```python
    plt.show()
```

#

Evaluate trained Cycle GAN

```python
def evaluate():
```

#

Set the run UUID from the training run

```python
    trained_run_uuid = 'f73c1164184711eb9190b74249275441'
```

#

Create configs object

```python
    conf = Configs()
```

#

Create experiment

```python
    experiment.create(name='cycle_gan_inference')
```

#

Load hyper parameters set for training

```python
    conf_dict = experiment.load_configs(trained_run_uuid)
```

#

Calculate configurations using the hyper-parameters loaded from the training run, and then initialize them. conf.initialize() creates the generators, the discriminators, the optimizers and the data loaders, so the dataset will be downloaded if it is not already available.

```python
    experiment.configs(conf, conf_dict)
    conf.initialize()
```

#

Register models for saving and loading. get_modules gives a dictionary of nn.Modules in conf. You can also specify a custom dictionary of models.

```python
    experiment.add_pytorch_models(get_modules(conf))
```

#

Specify which run to load from. Loading will actually happen when you call experiment.start

```python
    experiment.load(trained_run_uuid)
```

#

Start the experiment

```python
    with experiment.start():
```

#

Image transformations

```python
        transforms_ = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
```

#

Load your own data. Here we try the test set. I was trying with Yosemite photos; they look awesome. Here conf.dataset_name is the dataset name from the loaded training configurations.

```python
        dataset = ImageDataset(conf.dataset_name, transforms_, 'test')
```

#

Get an image from dataset

```python
        x_image = dataset[10]['x']
```

#

Display the image

```python
        plot_image(x_image)
```

#

Evaluation mode

```python
        conf.generator_xy.eval()
        conf.generator_yx.eval()
```

#

We don't need gradients

```python
        with torch.no_grad():
```

#

Add batch dimension and move to the device we use

```python
            data = x_image.unsqueeze(0).to(conf.device)
            generated_y = conf.generator_xy(data)
```

#

Display the generated image.

```python
        plot_image(generated_y[0].cpu())


if __name__ == '__main__':
    train()
    # evaluate()
```
