
StyleGAN 2


[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/gan/stylegan/__init__.py)


This is a PyTorch implementation of the paper Analyzing and Improving the Image Quality of StyleGAN, which introduces StyleGAN 2. StyleGAN 2 is an improvement over StyleGAN from the paper A Style-Based Generator Architecture for Generative Adversarial Networks. StyleGAN, in turn, is based on Progressive GAN from the paper Progressive Growing of GANs for Improved Quality, Stability, and Variation. All three papers are from the same authors at NVIDIA AI.

Our implementation is minimalistic StyleGAN 2 training code. Only single-GPU training is supported, to keep the implementation simple. We kept it under 500 lines of code, including the training loop.

🏃 Here's the training code: experiment.py.

These are 64×64 images generated after training for about 80K steps.

We'll first introduce the three papers at a high level.

Generative Adversarial Networks

Generative adversarial networks have two components: the generator and the discriminator. The generator network takes a random latent vector ($z \in \mathcal{Z}$) and tries to generate a realistic image. The discriminator network tries to tell real images apart from generated images. When the two networks are trained together, the generator learns to produce images that are hard to distinguish from real ones.
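To make the adversarial game concrete, here is a minimal sketch of one common pair of objectives: the non-saturating logistic GAN loss. This is an assumption for illustration. The losses this implementation actually uses are defined in the training code and may differ, and the function names below are hypothetical.

```python
import torch
import torch.nn.functional as F


def discriminator_loss(real_scores: torch.Tensor, fake_scores: torch.Tensor) -> torch.Tensor:
    # Push D(x) up for real images and down for generated ones:
    # softplus(-s) = -log(sigmoid(s)) and softplus(s) = -log(1 - sigmoid(s))
    return F.softplus(-real_scores).mean() + F.softplus(fake_scores).mean()


def generator_loss(fake_scores: torch.Tensor) -> torch.Tensor:
    # Non-saturating generator loss: maximize log D(G(z))
    return F.softplus(-fake_scores).mean()


# A discriminator that scores everything 0 (sigmoid = 0.5) cannot tell the two apart
zeros = torch.zeros(8, 1)
print(discriminator_loss(zeros, zeros).item())  # ≈ 1.386 (2 ln 2)
print(generator_loss(zeros).item())             # ≈ 0.693 (ln 2)
```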

Progressive GAN

Progressive GAN generates high-resolution images (1024×1024). It does so by progressively increasing the image size. First, it trains a network that produces a 4×4 image, then 8×8, then 16×16, and so on up to the desired image resolution.

At each resolution, the generator network produces an image in latent space, which is converted into RGB with a 1×1 convolution. When we progress from a lower resolution to a higher one (say from 4×4 to 8×8), we scale the latent image up by 2× and add a new block (two 3×3 convolution layers) and a new 1×1 layer to get RGB. The transition is made smooth by adding a residual connection to the 2× up-scaled 4×4 RGB image. The weight of this residual connection is slowly reduced to let the new block take over.
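The smooth transition can be sketched as a linear blend between the up-scaled RGB output of the old resolution and the RGB output of the new block. Two details here are assumptions for illustration: the blend weight `alpha`, ramped from 0 to 1 during training, and the nearest-neighbour up-scaling. The implementation on this page does not use progressive growing at all.

```python
import torch
import torch.nn.functional as F


def fade_in(old_rgb: torch.Tensor, new_rgb: torch.Tensor, alpha: float) -> torch.Tensor:
    """Blend the 2× up-scaled low-resolution RGB image with the new block's RGB output.

    `alpha` goes from 0 (only the old, up-scaled path) to 1 (only the new block),
    so the weight of the residual path, 1 - alpha, is slowly reduced.
    """
    upscaled = F.interpolate(old_rgb, scale_factor=2, mode='nearest')
    return (1 - alpha) * upscaled + alpha * new_rgb


old = torch.rand(2, 3, 4, 4)   # RGB output of the 4×4 stage
new = torch.rand(2, 3, 8, 8)   # RGB output of the newly added 8×8 block
print(fade_in(old, new, 0.3).shape)  # torch.Size([2, 3, 8, 8])
```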

The discriminator is a mirror image of the generator network. The progressive growth of the discriminator is done similarly.

2× and 0.5× denote feature map resolution up-scaling and down-scaling. 4×4, 8×8, ... denote the feature map resolution at each generator or discriminator block. Each discriminator and generator block consists of 2 convolution layers with leaky ReLU activations.

They use minibatch standard deviation to increase variation, and an equalized learning rate. We discuss both below in the implementation. They also use pixel-wise normalization, where the feature vector at each pixel is normalized. They apply this to all the convolution layer outputs (except RGB).
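Pixel-wise normalization divides each pixel's feature vector by its root-mean-square over the channels. Here is a minimal sketch, assuming a `[batch, channels, height, width]` layout and `eps = 1e-8`. Both are assumptions, and the StyleGAN 2 implementation on this page does not use pixel norm in its convolution layers.

```python
import torch


def pixel_norm(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # x has shape [batch, channels, height, width]. Normalize over the channel
    # dimension so the feature vector at every pixel has unit root-mean-square.
    return x * torch.rsqrt(x.pow(2).mean(dim=1, keepdim=True) + eps)


x = torch.randn(2, 16, 8, 8) * 5
y = pixel_norm(x)
print(y.pow(2).mean(dim=1).mean().item())  # ≈ 1.0
```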

StyleGAN

StyleGAN improves the generator of Progressive GAN, keeping the discriminator architecture the same.

Mapping Network

It maps the random latent vector ($z \in \mathcal{Z}$) into a different latent space ($w \in \mathcal{W}$) with an 8-layer neural network. This gives an intermediate latent space $\mathcal{W}$ where the factors of variation are more linear (disentangled).

AdaIN

Then $w$ is transformed into two vectors (styles) per layer $i$, $y_i = (y_{s,i}, y_{b,i}) = f_{A_i}(w)$. These are used for scaling and shifting (biasing) in each layer with the AdaIN operator (normalize and scale):

$$\text{AdaIN}(x_i, y_i) = y_{s,i} \frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i}$$
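The AdaIN operator can be sketched as instance normalization over the spatial dimensions, followed by a per-channel scale and bias. In this sketch the style $(y_s, y_b)$ is passed in directly. In StyleGAN it comes from the learned affine map $f_{A_i}$ of $w$, which is omitted here. The `eps` value is an assumption.

```python
import torch


def adain(x: torch.Tensor, y_s: torch.Tensor, y_b: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """AdaIN(x, y) = y_s * (x - mu(x)) / sigma(x) + y_b

    x:   feature map of shape [batch, channels, height, width]
    y_s: per-channel scale of shape [batch, channels]
    y_b: per-channel bias of shape [batch, channels]
    """
    # Mean and standard deviation per sample and per channel, over the spatial dimensions
    mu = x.mean(dim=(2, 3), keepdim=True)
    sigma = torch.sqrt(((x - mu) ** 2).mean(dim=(2, 3), keepdim=True) + eps)
    # Normalize, then apply the style's scale and bias
    return y_s[:, :, None, None] * (x - mu) / sigma + y_b[:, :, None, None]


x = torch.randn(2, 4, 8, 8) * 3 + 1
out = adain(x, torch.full((2, 4), 2.0), torch.full((2, 4), 0.5))
# Every channel of `out` now has mean ≈ 0.5 and standard deviation ≈ 2.0
print(out.mean(dim=(2, 3)))
```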

Style Mixing

To prevent the generator from assuming adjacent styles are correlated, they randomly use different styles for different blocks. That is, they sample two latent vectors $(z_1, z_2)$ and the corresponding $(w_1, w_2)$. Styles based on $w_1$ are used for some blocks and styles based on $w_2$ for the others, chosen at random.
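Style mixing can be sketched as picking a random crossover block. The blocks before it use $w_1$ and the rest use $w_2$, which produces the `[n_blocks, batch_size, d_latent]` layout that the generator below takes. Three things here are assumptions for illustration: the function name, drawing the crossover point uniformly from `[1, n_blocks)` so that both latents are used, and leaving out the mapping network.

```python
import torch


def mix_styles(w1: torch.Tensor, w2: torch.Tensor, n_blocks: int, crossover: int) -> torch.Tensor:
    """Build per-block latents from two [batch_size, d_latent] tensors.

    Blocks [0, crossover) use w1 and blocks [crossover, n_blocks) use w2.
    Returns a tensor of shape [n_blocks, batch_size, d_latent].
    """
    w1 = w1[None, :, :].expand(crossover, -1, -1)
    w2 = w2[None, :, :].expand(n_blocks - crossover, -1, -1)
    return torch.cat((w1, w2), dim=0)


n_blocks, batch_size, d_latent = 6, 2, 8
w1, w2 = torch.randn(batch_size, d_latent), torch.randn(batch_size, d_latent)
# Random crossover point for this training step
crossover = int(torch.randint(1, n_blocks, ()).item())
w = mix_styles(w1, w2, n_blocks, crossover)
print(w.shape)  # torch.Size([6, 2, 8])
```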

Stochastic Variation

Noise is made available to each block, which helps the generator create more realistic images. Noise is scaled per channel by a learned weight.

Bilinear Up and Down Sampling

All the up and down-sampling operations are accompanied by bilinear smoothing.

A denotes a linear layer. B denotes a broadcast and scaling operation (noise is a single channel). StyleGAN also uses progressive growing like Progressive GAN.

StyleGAN 2

StyleGAN 2 changes both the generator and the discriminator of StyleGAN.

Weight Modulation and Demodulation

They remove the AdaIN operator and replace it with a weight modulation and demodulation step. This is meant to remove what they call droplet artifacts in the generated images, which are caused by the normalization in the AdaIN operator. The style vector for each layer is calculated from $w_i \in \mathcal{W}$ as $s_i = f_{A_i}(w_i)$.

Then the convolution weights $w$ are modulated as follows. (From here on, $w$ refers to the weights, not the intermediate latent space; we stick to the same notation as the paper.)

$$w'_{i,j,k} = s_i \cdot w_{i,j,k}$$

Then the weights are demodulated by normalizing:

$$w''_{i,j,k} = \frac{w'_{i,j,k}}{\sqrt{\sum_{i,k} {w'_{i,j,k}}^2 + \epsilon}}$$

where $i$ is the input channel, $j$ is the output channel, and $k$ is the kernel index.
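These two steps can be checked numerically. After modulation and demodulation, the weights feeding each output channel $j$ have unit L2 norm (up to $\epsilon$), whatever the style scales $s_i$ are. Here is a minimal sketch with weights laid out as `[out_channels, in_channels, k, k]`. That layout is an assumption chosen to match PyTorch's convolution weights.

```python
import torch


def modulate_demodulate(weight: torch.Tensor, s: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """weight: [out_channels j, in_channels i, k, k]; s: [in_channels] style scales."""
    # Modulate: w'_{i,j,k} = s_i * w_{i,j,k}
    w_mod = weight * s[None, :, None, None]
    # Demodulate: divide by sqrt(sum over i and k of w'^2 + eps), separately for each j
    sigma = torch.sqrt((w_mod ** 2).sum(dim=(1, 2, 3), keepdim=True) + eps)
    return w_mod / sigma


torch.manual_seed(0)
weight = torch.randn(8, 4, 3, 3)
s = torch.tensor([0.1, 1.0, 5.0, 20.0])
w_demod = modulate_demodulate(weight, s)
print((w_demod ** 2).sum(dim=(1, 2, 3)))  # ≈ one per output channel
```

Each output channel ends up with unit-norm weights. So inputs with i.i.d. unit-variance activations give outputs of unit variance in expectation. This statistical argument is how demodulation replaces AdaIN's explicit normalization of the activations.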

Path Length Regularization

Path length regularization encourages a fixed-size step in W to result in a non-zero, fixed-magnitude change in the generated image.

No Progressive Growing

StyleGAN 2 uses residual connections (with down-sampling) in the discriminator and skip connections with up-sampling in the generator. In the generator, the RGB outputs from each layer are added; there are no residual connections in the feature maps. They show experimentally that the low-resolution layers contribute more at the beginning of training, after which the high-resolution layers take over.

```python
import math
from typing import Tuple, Optional, List

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn
```

Mapping Network

This is an MLP with 8 linear layers. The mapping network maps the latent vector $z \in \mathcal{Z}$ to an intermediate latent space $w \in \mathcal{W}$. The $\mathcal{W}$ space is disentangled from the image space, so the factors of variation become more linear.

```python
class MappingNetwork(nn.Module):
    def __init__(self, features: int, n_layers: int):
        """
        * `features` is the number of features in z and w
        * `n_layers` is the number of layers in the mapping network
        """
        super().__init__()

        # Create the MLP
        layers = []
        for i in range(n_layers):
            # Equalized learning-rate linear layers
            layers.append(EqualizedLinear(features, features))
            # Leaky ReLU
            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))

        self.net = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor):
        # Normalize z
        z = F.normalize(z, dim=1)
        # Map z to w
        return self.net(z)
```

StyleGAN2 Generator

A denotes a linear layer. B denotes a broadcast and scaling operation (noise is a single channel). toRGB also has a style modulation, which is not shown in the diagram to keep it simple.

The generator starts with a learned constant, followed by a series of blocks. The feature map resolution is doubled at each block. Each block outputs an RGB image, and these are up-sampled and summed to get the final RGB image.

```python
class Generator(nn.Module):
    def __init__(self, log_resolution: int, d_latent: int, n_features: int = 32, max_features: int = 512):
        """
        * `log_resolution` is the log2 of the image resolution
        * `d_latent` is the dimensionality of w
        * `n_features` is the number of features in the convolution layer at the highest resolution (final block)
        * `max_features` is the maximum number of features in any generator block
        """
        super().__init__()

        # Calculate the number of features for each block,
        # something like [512, 512, 256, 128, 64, 32]
        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 2, -1, -1)]
        # Number of generator blocks
        self.n_blocks = len(features)

        # Trainable 4×4 constant
        self.initial_constant = nn.Parameter(torch.randn((1, features[0], 4, 4)))

        # First style block for 4×4 resolution and layer to get RGB
        self.style_block = StyleBlock(d_latent, features[0], features[0])
        self.to_rgb = ToRGB(d_latent, features[0])

        # Generator blocks
        blocks = [GeneratorBlock(d_latent, features[i - 1], features[i]) for i in range(1, self.n_blocks)]
        self.blocks = nn.ModuleList(blocks)

        # 2× up-sampling layer. The feature space is up-sampled at each block
        self.up_sample = UpSample()

    def forward(self, w: torch.Tensor, input_noise: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]):
        """
        * `w` is w. In order to mix styles (use different w for different layers), we provide
          a separate w for each generator block. It has shape [n_blocks, batch_size, d_latent].
        * `input_noise` is the noise for each block. It's a list of pairs of noise tensors because
          each block (except the initial one) has two noise inputs, one after each convolution
          layer (see the diagram).
        """
        # Get batch size
        batch_size = w.shape[1]

        # Expand the learned constant to match batch size
        x = self.initial_constant.expand(batch_size, -1, -1, -1)

        # The first style block
        x = self.style_block(x, w[0], input_noise[0][1])
        # Get first RGB image
        rgb = self.to_rgb(x, w[0])

        # Evaluate the rest of the blocks
        for i in range(1, self.n_blocks):
            # Up-sample the feature map
            x = self.up_sample(x)
            # Run it through the generator block
            x, rgb_new = self.blocks[i - 1](x, w[i], input_noise[i])
            # Up-sample the RGB image and add it to the RGB output of the block
            rgb = self.up_sample(rgb) + rgb_new

        # Return the final RGB image
        return rgb
```
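The feature schedule and the output resolution follow from some simple arithmetic. The sketch below reproduces the `features` list computed in `Generator.__init__`. It uses `log_resolution=7` (128×128), a value chosen for illustration. It then checks that the 4×4 constant followed by five up-sampling blocks reaches 128×128.

```python
def generator_features(log_resolution: int, n_features: int = 32, max_features: int = 512) -> list:
    # Same expression as in Generator.__init__: widest at 4×4, narrowest at full resolution
    return [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 2, -1, -1)]


features = generator_features(7)
print(features)  # [512, 512, 256, 128, 64, 32]

# The first (4×4) style block is followed by len(features) - 1 up-sampling generator blocks
resolution = 4 * 2 ** (len(features) - 1)
print(resolution)  # 128
```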

Generator Block

A denotes a linear layer. B denotes a broadcast and scaling operation (noise is a single channel). toRGB also has a style modulation, which is not shown in the diagram to keep it simple.

The generator block consists of two style blocks (3×3 convolutions with style modulation) and an RGB output.

```python
class GeneratorBlock(nn.Module):
    def __init__(self, d_latent: int, in_features: int, out_features: int):
        """
        * `d_latent` is the dimensionality of w
        * `in_features` is the number of features in the input feature map
        * `out_features` is the number of features in the output feature map
        """
        super().__init__()

        # First style block changes the number of feature channels to `out_features`
        self.style_block1 = StyleBlock(d_latent, in_features, out_features)
        # Second style block
        self.style_block2 = StyleBlock(d_latent, out_features, out_features)

        # toRGB layer
        self.to_rgb = ToRGB(d_latent, out_features)

    def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]):
        """
        * `x` is the input feature map of shape [batch_size, in_features, height, width]
        * `w` is w with shape [batch_size, d_latent]
        * `noise` is a tuple of two noise tensors of shape [batch_size, 1, height, width]
        """
        # First style block with the first noise tensor.
        # The output is of shape [batch_size, out_features, height, width]
        x = self.style_block1(x, w, noise[0])
        # Second style block with the second noise tensor.
        # The output is of shape [batch_size, out_features, height, width]
        x = self.style_block2(x, w, noise[1])

        # Get RGB image
        rgb = self.to_rgb(x, w)

        # Return feature map and RGB image
        return x, rgb
```

Style Block

A denotes a linear layer. B denotes a broadcast and scaling operation (noise is a single channel).

The style block has a weight-modulated convolution layer.

```python
class StyleBlock(nn.Module):
    def __init__(self, d_latent: int, in_features: int, out_features: int):
        """
        * `d_latent` is the dimensionality of w
        * `in_features` is the number of features in the input feature map
        * `out_features` is the number of features in the output feature map
        """
        super().__init__()
        # Get style vector from w (denoted by A in the diagram) with an equalized learning-rate linear layer
        self.to_style = EqualizedLinear(d_latent, in_features, bias=1.0)
        # Weight modulated convolution layer
        self.conv = Conv2dWeightModulate(in_features, out_features, kernel_size=3)
        # Noise scale
        self.scale_noise = nn.Parameter(torch.zeros(1))
        # Bias
        self.bias = nn.Parameter(torch.zeros(out_features))

        # Activation function
        self.activation = nn.LeakyReLU(0.2, True)

    def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Optional[torch.Tensor]):
        """
        * `x` is the input feature map of shape [batch_size, in_features, height, width]
        * `w` is w with shape [batch_size, d_latent]
        * `noise` is a tensor of shape [batch_size, 1, height, width]
        """
        # Get style vector s
        s = self.to_style(w)
        # Weight modulated convolution
        x = self.conv(x, s)
        # Scale and add noise
        if noise is not None:
            x = x + self.scale_noise[None, :, None, None] * noise
        # Add bias and evaluate activation function
        return self.activation(x + self.bias[None, :, None, None])
```

To RGB

A denotes a linear layer.

Generates an RGB image from a feature map using a 1×1 convolution.

```python
class ToRGB(nn.Module):
    def __init__(self, d_latent: int, features: int):
        """
        * `d_latent` is the dimensionality of w
        * `features` is the number of features in the feature map
        """
        super().__init__()
        # Get style vector from w (denoted by A in the diagram) with an equalized learning-rate linear layer
        self.to_style = EqualizedLinear(d_latent, features, bias=1.0)

        # Weight modulated convolution layer without demodulation
        self.conv = Conv2dWeightModulate(features, 3, kernel_size=1, demodulate=False)
        # Bias
        self.bias = nn.Parameter(torch.zeros(3))
        # Activation function
        self.activation = nn.LeakyReLU(0.2, True)

    def forward(self, x: torch.Tensor, w: torch.Tensor):
        """
        * `x` is the input feature map of shape [batch_size, features, height, width]
        * `w` is w with shape [batch_size, d_latent]
        """
        # Get style vector s
        style = self.to_style(w)
        # Weight modulated convolution
        x = self.conv(x, style)
        # Add bias and evaluate activation function
        return self.activation(x + self.bias[None, :, None, None])
```

Convolution with Weight Modulation and Demodulation

This layer scales the convolution weights by the style vector and demodulates them by normalizing.

```python
class Conv2dWeightModulate(nn.Module):
    def __init__(self, in_features: int, out_features: int, kernel_size: int,
                 demodulate: bool = True, eps: float = 1e-8):
        """
        * `in_features` is the number of features in the input feature map
        * `out_features` is the number of features in the output feature map
        * `kernel_size` is the size of the convolution kernel
        * `demodulate` is a flag for whether to normalize the weights by their standard deviation
        * `eps` is the ϵ for normalizing
        """
        super().__init__()
        # Number of output features
        self.out_features = out_features
        # Whether to normalize weights
        self.demodulate = demodulate
        # Padding size
        self.padding = (kernel_size - 1) // 2

        # Weights parameter with equalized learning rate
        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
        # ϵ
        self.eps = eps

    def forward(self, x: torch.Tensor, s: torch.Tensor):
        """
        * `x` is the input feature map of shape [batch_size, in_features, height, width]
        * `s` is the style-based scaling tensor of shape [batch_size, in_features]
        """
        # Get batch size, height and width
        b, _, h, w = x.shape

        # Reshape the scales
        s = s[:, None, :, None, None]
        # Get learning-rate equalized weights
        weights = self.weight()[None, :, :, :, :]
        # Modulate: w'_{i,j,k} = s_i * w_{i,j,k}, where i is the input channel,
        # j is the output channel, and k is the kernel index.
        # The result has shape [batch_size, out_features, in_features, kernel_size, kernel_size]
        weights = weights * s

        # Demodulate
        if self.demodulate:
            # 1 / sigma_j, where sigma_j = sqrt(sum_{i,k} (w'_{i,j,k})^2 + eps)
            sigma_inv = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
            # w''_{i,j,k} = w'_{i,j,k} / sigma_j
            weights = weights * sigma_inv

        # Reshape x
        x = x.reshape(1, -1, h, w)

        # Reshape weights
        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.out_features, *ws)

        # Use grouped convolution to efficiently calculate the convolution with a sample-wise kernel,
        # i.e. we have a different kernel (weights) for each sample in the batch
        x = F.conv2d(x, weights, padding=self.padding, groups=b)

        # Reshape x to [batch_size, out_features, height, width] and return
        return x.reshape(-1, self.out_features, h, w)
```
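The grouped-convolution trick can be checked against a plain loop that convolves each sample with its own kernel. This sketch uses random per-sample weights in place of the modulated ones, and the shapes are arbitrary example values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, in_features, out_features, k, h, w = 3, 4, 5, 3, 8, 8
x = torch.randn(b, in_features, h, w)
# One kernel per sample: [batch_size, out_features, in_features, k, k]
weights = torch.randn(b, out_features, in_features, k, k)

# Grouped convolution: fold the batch into the channel dimension, one group per sample
grouped = F.conv2d(x.reshape(1, -1, h, w),
                   weights.reshape(b * out_features, in_features, k, k),
                   padding=(k - 1) // 2, groups=b).reshape(b, out_features, h, w)

# Reference: convolve each sample separately with its own kernel
looped = torch.cat([F.conv2d(x[i:i + 1], weights[i], padding=(k - 1) // 2) for i in range(b)])

print(torch.allclose(grouped, looped, atol=1e-4))  # True
```

A single grouped call avoids a Python loop over the batch while still giving every sample its own style-dependent kernel.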

StyleGAN 2 Discriminator

The discriminator first transforms the image to a feature map of the same resolution and then runs it through a series of blocks with residual connections. The resolution is down-sampled by 2× at each block while the number of features doubles (capped at `max_features`).

```python
class Discriminator(nn.Module):
    def __init__(self, log_resolution: int, n_features: int = 64, max_features: int = 512):
        """
        * `log_resolution` is the log2 of the image resolution
        * `n_features` is the number of features in the convolution layer at the highest resolution (first block)
        * `max_features` is the maximum number of features in any discriminator block
        """
        super().__init__()

        # Layer to convert the RGB image to a feature map with `n_features` features
        self.from_rgb = nn.Sequential(
            EqualizedConv2d(3, n_features, 1),
            nn.LeakyReLU(0.2, True),
        )

        # Calculate the number of features for each block,
        # something like [64, 128, 256, 512, 512, 512]
        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 1)]
        # Number of discriminator blocks
        n_blocks = len(features) - 1
        # Discriminator blocks
        blocks = [DiscriminatorBlock(features[i], features[i + 1]) for i in range(n_blocks)]
        self.blocks = nn.Sequential(*blocks)

        # Mini-batch standard deviation
        self.std_dev = MiniBatchStdDev()
        # Number of features after adding the standard deviations map
        final_features = features[-1] + 1
        # Final 3×3 convolution layer (no padding, so a 4×4 feature map becomes 2×2)
        self.conv = EqualizedConv2d(final_features, final_features, 3)
        # Final linear layer to get the classification
        self.final = EqualizedLinear(2 * 2 * final_features, 1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the input image of shape [batch_size, 3, height, width]
        """
        # Try to normalize the image (this is totally optional, but sped up the early training a little)
        x = x - 0.5
        # Convert from RGB
        x = self.from_rgb(x)
        # Run through the discriminator blocks
        x = self.blocks(x)

        # Calculate and append mini-batch standard deviation
        x = self.std_dev(x)
        # 3×3 convolution
        x = self.conv(x)
        # Flatten
        x = x.reshape(x.shape[0], -1)
        # Return the classification score
        return self.final(x)
```
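The same kind of arithmetic applies to the discriminator, again with `log_resolution=7` (128×128) as an example value. The `features` list matches the one quoted in the comment, and five blocks halve 128 down to 4. The unpadded 3×3 convolution then leaves the 2×2 map that the final linear layer expects.

```python
def discriminator_shapes(log_resolution: int, n_features: int = 64, max_features: int = 512):
    # Same expression as in Discriminator.__init__
    features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 1)]
    n_blocks = len(features) - 1
    # Each DiscriminatorBlock halves the resolution
    size = 2 ** log_resolution // 2 ** n_blocks
    # A 3×3 convolution without padding shrinks each side by 2
    size_after_conv = size - 2
    # The mini-batch std-dev map adds one channel before the final layers
    linear_in = size_after_conv * size_after_conv * (features[-1] + 1)
    return features, size, size_after_conv, linear_in


print(discriminator_shapes(7))  # ([64, 128, 256, 512, 512, 512], 4, 2, 2052)
```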

Discriminator Block

The discriminator block consists of two 3×3 convolutions with a residual connection.

```python
class DiscriminatorBlock(nn.Module):
    def __init__(self, in_features, out_features):
        """
        * `in_features` is the number of features in the input feature map
        * `out_features` is the number of features in the output feature map
        """
        super().__init__()
        # Down-sampling and 1×1 convolution layer for the residual connection
        self.residual = nn.Sequential(DownSample(),
                                      EqualizedConv2d(in_features, out_features, kernel_size=1))

        # Two 3×3 convolutions
        self.block = nn.Sequential(
            EqualizedConv2d(in_features, in_features, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            EqualizedConv2d(in_features, out_features, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
        )

        # Down-sampling layer
        self.down_sample = DownSample()

        # Scaling factor 1/√2 after adding the residual
        self.scale = 1 / math.sqrt(2)

    def forward(self, x):
        # Get the residual connection
        residual = self.residual(x)

        # Convolutions
        x = self.block(x)
        # Down-sample
        x = self.down_sample(x)

        # Add the residual and scale
        return (x + residual) * self.scale
```

Mini-batch Standard Deviation

Mini-batch standard deviation calculates the standard deviation across a mini-batch (or subgroups within the mini-batch) for each feature in the feature map. Then it takes the mean of all the standard deviations and appends it to the feature map as one extra feature.

For each feature $i$, over the $N$ samples $g$ of a group:

$$\mu_i = \frac{1}{N} \sum_g x_{g,i}, \qquad \sigma_i = \sqrt{\frac{1}{N} \sum_g \left(x_{g,i} - \mu_i\right)^2 + \epsilon}$$

```python
class MiniBatchStdDev(nn.Module):
    def __init__(self, group_size: int = 4):
        """
        * `group_size` is the number of samples to calculate the standard deviation across
        """
        super().__init__()
        self.group_size = group_size

    def forward(self, x: torch.Tensor):
        """
        * `x` is the feature map
        """
        # Check if the batch size is divisible by the group size
        assert x.shape[0] % self.group_size == 0
        # Split the samples into groups of `group_size`. We flatten the feature map to a single
        # dimension since we want to calculate the standard deviation for each feature.
        grouped = x.view(self.group_size, -1)
        # Calculate the standard deviation for each feature among `group_size` samples.
        # Note that `var` uses the unbiased estimator (dividing by N - 1) by default.
        std = torch.sqrt(grouped.var(dim=0) + 1e-8)
        # Get the mean standard deviation
        std = std.mean().view(1, 1, 1, 1)
        # Expand the standard deviation to append to the feature map
        b, _, h, w = x.shape
        std = std.expand(b, -1, h, w)
        # Append (concatenate) the standard deviations to the feature map
        return torch.cat([x, std], dim=1)
```
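Here is a functional sketch of the same computation that shows what the layer adds; the name `minibatch_std_feature` is hypothetical. The output gains exactly one channel, which holds the same scalar at every position. That scalar grows when the samples in a batch differ more, and this is the signal the discriminator can use to spot a lack of variation.

```python
import torch


def minibatch_std_feature(x: torch.Tensor, group_size: int = 4) -> torch.Tensor:
    # Mirrors MiniBatchStdDev.forward
    assert x.shape[0] % group_size == 0
    grouped = x.reshape(group_size, -1)
    std = torch.sqrt(grouped.var(dim=0) + 1e-8).mean().view(1, 1, 1, 1)
    b, _, h, w = x.shape
    return torch.cat([x, std.expand(b, -1, h, w)], dim=1)


torch.manual_seed(0)
identical = torch.randn(1, 3, 4, 4).repeat(8, 1, 1, 1)  # no variation across samples
varied = torch.randn(8, 3, 4, 4)                        # independent samples

out_identical = minibatch_std_feature(identical)
out_varied = minibatch_std_feature(varied)
print(out_identical.shape)                  # torch.Size([8, 4, 4, 4])
print(out_identical[0, 3, 0, 0].item())     # ≈ 1e-4, i.e. sqrt(0 + 1e-8)
print(out_varied[0, 3, 0, 0].item() > 0.5)  # True
```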

Down-sample

The down-sample operation smooths each feature channel and scales it down by 2× using bilinear interpolation. This is based on the paper Making Convolutional Networks Shift-Invariant Again.

```python
class DownSample(nn.Module):
    def __init__(self):
        super().__init__()
        # Smoothing layer
        self.smooth = Smooth()

    def forward(self, x: torch.Tensor):
        # Smoothing or blurring
        x = self.smooth(x)
        # Scale down
        return F.interpolate(x, (x.shape[2] // 2, x.shape[3] // 2), mode='bilinear', align_corners=False)
```

Up-sample

The up-sample operation scales the image up by 2× and smooths each feature channel. This is based on the paper Making Convolutional Networks Shift-Invariant Again.

```python
class UpSample(nn.Module):
    def __init__(self):
        super().__init__()
        # Up-sampling layer
        self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        # Smoothing layer
        self.smooth = Smooth()

    def forward(self, x: torch.Tensor):
        # Up-sample and smooth
        return self.smooth(self.up_sample(x))
```

Smoothing Layer

This layer blurs each channel.

```python
class Smooth(nn.Module):
    def __init__(self):
        super().__init__()
        # Blurring kernel
        kernel = [[1, 2, 1],
                  [2, 4, 2],
                  [1, 2, 1]]
        # Convert the kernel to a PyTorch tensor
        kernel = torch.tensor([[kernel]], dtype=torch.float)
        # Normalize the kernel
        kernel /= kernel.sum()
        # Save kernel as a fixed parameter (no gradient updates)
        self.kernel = nn.Parameter(kernel, requires_grad=False)
        # Padding layer
        self.pad = nn.ReplicationPad2d(1)

    def forward(self, x: torch.Tensor):
        # Get shape of the input feature map
        b, c, h, w = x.shape
        # Reshape for smoothing
        x = x.view(-1, 1, h, w)

        # Add padding
        x = self.pad(x)

        # Smooth (blur) with the kernel
        x = F.conv2d(x, self.kernel)

        # Reshape and return
        return x.view(b, c, h, w)
```
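Two properties of the blur kernel are easy to check. First, the normalized 3×3 kernel is the outer product of the 1-D binomial filter [1, 2, 1] / 4 with itself. Second, the kernel sums to one and the padding replicates border pixels, so a constant feature map passes through unchanged. A standalone sketch of the same operation:

```python
import torch
import torch.nn.functional as F
from torch import nn

# Normalized blur kernel, as in Smooth.__init__
kernel = torch.tensor([[1., 2., 1.],
                       [2., 4., 2.],
                       [1., 2., 1.]])
kernel /= kernel.sum()

# It is separable: the outer product of the 1-D binomial filter with itself
binomial = torch.tensor([1., 2., 1.]) / 4
print(torch.allclose(kernel, torch.outer(binomial, binomial)))  # True


def smooth(x: torch.Tensor) -> torch.Tensor:
    # Blur every channel independently, as in Smooth.forward
    b, c, h, w = x.shape
    x = nn.ReplicationPad2d(1)(x.reshape(-1, 1, h, w))
    return F.conv2d(x, kernel[None, None]).reshape(b, c, h, w)


constant = torch.full((2, 3, 5, 5), 7.0)
print(torch.allclose(smooth(constant), constant))  # True
```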

Learning-rate Equalized Linear Layer

This uses learning-rate equalized weights for a linear layer.

```python
class EqualizedLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: float = 0.):
        """
        * `in_features` is the number of features in the input feature map
        * `out_features` is the number of features in the output feature map
        * `bias` is the bias initialization constant
        """
        super().__init__()
        # Learning-rate equalized weights
        self.weight = EqualizedWeight([out_features, in_features])
        # Bias
        self.bias = nn.Parameter(torch.ones(out_features) * bias)

    def forward(self, x: torch.Tensor):
        # Linear transformation
        return F.linear(x, self.weight(), bias=self.bias)
```

Learning-rate Equalized 2D Convolution Layer

This uses learning-rate equalized weights for a convolution layer.

```python
class EqualizedConv2d(nn.Module):
    def __init__(self, in_features: int, out_features: int,
                 kernel_size: int, padding: int = 0):
        """
        * `in_features` is the number of features in the input feature map
        * `out_features` is the number of features in the output feature map
        * `kernel_size` is the size of the convolution kernel
        * `padding` is the padding to be added on both sides of each size dimension
        """
        super().__init__()
        # Padding size
        self.padding = padding
        # Learning-rate equalized weights
        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
        # Bias
        self.bias = nn.Parameter(torch.ones(out_features))

    def forward(self, x: torch.Tensor):
        # Convolution
        return F.conv2d(x, self.weight(), bias=self.bias, padding=self.padding)
```

Learning-rate Equalized Weights Parameter

This is based on the equalized learning rate introduced in the Progressive GAN paper. Instead of initializing weights at $\mathcal{N}(0, c)$, they initialize weights to $\mathcal{N}(0, 1)$ and then multiply them by $c$ when using them: $w_i = c \hat{w}_i$.

The gradients on the stored parameters $\hat{w}$ get multiplied by $c$, but this has no effect since optimizers such as Adam normalize them by a running mean of the squared gradients.

The optimizer updates on $\hat{w}$ are proportional to the learning rate $\lambda$, but the effective weights $w$ get updated proportionally to $c \lambda$. Without equalized learning rate, the effective weights would get updated proportionally to just $\lambda$.

So we are effectively scaling the learning rate by $c$ for these weight parameters.

```python
class EqualizedWeight(nn.Module):
    def __init__(self, shape: List[int]):
        """
        * `shape` is the shape of the weight parameter
        """
        super().__init__()

        # He initialization constant, used as the weight multiplication coefficient c
        self.c = 1 / math.sqrt(np.prod(shape[1:]))
        # Initialize the weights with N(0, 1)
        self.weight = nn.Parameter(torch.randn(shape))

    def forward(self):
        # Multiply the weights by c and return
        return self.weight * self.c
```
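A quick experiment illustrates the effect of the constant $c$. Store the weights as $\mathcal{N}(0, 1)$ and scale them by $c = 1/\sqrt{\text{fan\_in}}$: a linear layer then maps unit-variance inputs to outputs of roughly unit variance. The gradient reaching the stored parameter is also exactly $c$ times the gradient on the effective weight. The sizes and seed are arbitrary example values.

```python
import math

import torch

torch.manual_seed(0)
fan_in, fan_out, n = 512, 256, 4096

w_hat = torch.randn(fan_out, fan_in)   # stored parameter, N(0, 1)
c = 1 / math.sqrt(fan_in)              # He initialization constant
x = torch.randn(n, fan_in)             # unit-variance inputs

y = x @ (c * w_hat).T                  # effective weights w = c * w_hat
print(round(y.var().item(), 2))        # close to 1.0

# The gradient on the stored parameter is c times the gradient on the effective weight
w_hat_param = w_hat.clone().requires_grad_(True)
w_eff = (c * w_hat).clone().requires_grad_(True)
(x[:8] @ (c * w_hat_param).T).square().sum().backward()
(x[:8] @ w_eff.T).square().sum().backward()
print(torch.allclose(w_hat_param.grad, c * w_eff.grad))  # True
```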

Gradient Penalty

This is the $R_1$ regularization penalty from the paper Which Training Methods for GANs do actually Converge?.

$$R_1(\psi) = \frac{\gamma}{2} \mathbb{E}_{p_\mathcal{D}(x)} \Big[\Vert \nabla_x D_\psi(x) \Vert^2\Big]$$

That is, we try to reduce the L2 norm of the gradients of the discriminator with respect to the images, for real images ($p_\mathcal{D}$).

```python
class GradientPenalty(nn.Module):
    def forward(self, x: torch.Tensor, d: torch.Tensor):
        """
        * `x` is x ∼ p_D, a batch of real images (it must have `requires_grad` set)
        * `d` is D(x)
        """
        # Get batch size
        batch_size = x.shape[0]

        # Calculate gradients of D(x) with respect to x.
        # `grad_outputs` is set to 1 since we want the gradients of D(x),
        # and we need to create the graph since we have to compute
        # gradients of this loss with respect to the weights.
        gradients, *_ = torch.autograd.grad(outputs=d,
                                            inputs=x,
                                            grad_outputs=d.new_ones(d.shape),
                                            create_graph=True)

        # Reshape gradients to calculate the norm
        gradients = gradients.reshape(batch_size, -1)
        # Calculate the norm ||∇_x D(x)||_2
        norm = gradients.norm(2, dim=-1)
        # Return the loss, the batch mean of ||∇_x D_ψ(x)||^2 (the γ/2 factor is not applied here)
        return torch.mean(norm ** 2)
```
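Here is a toy check of the penalty. The stand-in is a linear "discriminator" $D(x) = \sum a \cdot x$, chosen because its gradient with respect to every image is exactly `a`, so the returned value should equal $\Vert a \Vert^2$. The sketch inlines the same computation as `GradientPenalty.forward` so that it runs on its own.

```python
import torch


def r1_penalty(x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    # Same steps as GradientPenalty.forward (without the γ/2 factor)
    gradients, *_ = torch.autograd.grad(outputs=d, inputs=x,
                                        grad_outputs=d.new_ones(d.shape),
                                        create_graph=True)
    norm = gradients.reshape(x.shape[0], -1).norm(2, dim=-1)
    return torch.mean(norm ** 2)


torch.manual_seed(0)
a = torch.randn(3, 4, 4)                          # weights of the toy discriminator
x = torch.randn(5, 3, 4, 4, requires_grad=True)   # a batch of "real images"
d = (x * a).sum(dim=(1, 2, 3))                    # D(x), one score per image

penalty = r1_penalty(x, d)
print(torch.allclose(penalty, (a ** 2).sum()))    # True
```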

Path Length Penalty

This regularization encourages a fixed-size step in $w$ to result in a fixed-magnitude change in the image.

$$\mathbb{E}_{w \sim f(z), y \sim \mathcal{N}(0, \mathbf{I})} \Big(\Vert \mathbf{J}^\top_{w} y \Vert_2 - a \Big)^2$$

where $\mathbf{J}_w$ is the Jacobian $\mathbf{J}_w = \frac{\partial g}{\partial w}$, $w$ are sampled from $\mathcal{W}$ by the mapping network, and $y$ are random images with pixel values drawn from $\mathcal{N}(0, \mathbf{I})$.

$a$ is the exponential moving average of $\Vert \mathbf{J}^\top_{w} y \Vert_2$ as the training progresses.

$\mathbf{J}^\top_{w} y$ is calculated without explicitly computing the Jacobian, using

$$\mathbf{J}^\top_{w} y = \nabla_w \big(g(w) \cdot y \big)$$

```python
class PathLengthPenalty(nn.Module):
    def __init__(self, beta: float):
        """
        * `beta` is the constant β used to calculate the exponential moving average a
        """
        super().__init__()

        # β
        self.beta = beta
        # Number of steps calculated, N
        self.steps = nn.Parameter(torch.tensor(0.), requires_grad=False)
        # Exponential sum of ||J_w^T y||_2:
        #   (1 - β) * sum_{i=1}^{N} β^(N - i) [||J_w^T y||_2]_i
        # where [.]_i is its value at the i-th step of training
        self.exp_sum_a = nn.Parameter(torch.tensor(0.), requires_grad=False)

    def forward(self, w: torch.Tensor, x: torch.Tensor):
        """
        * `w` is the batch of w of shape [n_blocks, batch_size, d_latent],
          the per-block layout that the generator takes
        * `x` are the generated images of shape [batch_size, 3, height, width]
        """
        # Get the device
        device = x.device
        # Get number of pixels
        image_size = x.shape[2] * x.shape[3]
        # Calculate y ∼ N(0, I)
        y = torch.randn(x.shape, device=device)
        # Calculate (g(w) · y) and normalize by the square root of the image size.
        # This scaling is not mentioned in the paper but was present in their implementation.
        output = (x * y).sum() / math.sqrt(image_size)

        # Calculate gradients to get J_w^T y
        gradients, *_ = torch.autograd.grad(outputs=output,
                                            inputs=w,
                                            grad_outputs=torch.ones(output.shape, device=device),
                                            create_graph=True)

        # Calculate the L2-norm of J_w^T y for each sample: sum over d_latent (dim 2),
        # average over the generator blocks (dim 0), then take the square root.
        # Averaging over dim 1 instead would mix the samples of the batch.
        norm = (gradients ** 2).sum(dim=2).mean(dim=0).sqrt()

        # Regularize after the first step
        if self.steps > 0:
            # Calculate a = exp_sum_a / (1 - β^N), the bias-corrected moving average
            # (1 - β) / (1 - β^N) * sum_{i=1}^{N} β^(N - i) [||J_w^T y||_2]_i
            a = self.exp_sum_a / (1 - self.beta ** self.steps)
            # Calculate the penalty E[(||J_w^T y||_2 - a)^2]
            loss = torch.mean((norm - a) ** 2)
        else:
            # Return a dummy loss if we can't calculate a
            loss = norm.new_tensor(0)

        # Calculate the mean of ||J_w^T y||_2
        mean = norm.mean().detach()
        # Update the exponential sum
        self.exp_sum_a.mul_(self.beta).add_(mean, alpha=1 - self.beta)
        # Increment N
        self.steps.add_(1.)

        # Return the penalty
        return loss
```
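The identity $\mathbf{J}^\top_{w} y = \nabla_w \big(g(w) \cdot y \big)$ can be verified on a small stand-in for the generator, where the full Jacobian is cheap to build with `torch.autograd.functional.jacobian`. The function `g` below, a tanh of a linear map, is an arbitrary example and not part of StyleGAN.

```python
import torch
from torch.autograd.functional import jacobian

torch.manual_seed(0)
d_latent, n_pixels = 4, 6
A = torch.randn(n_pixels, d_latent)


def g(w: torch.Tensor) -> torch.Tensor:
    # Toy "generator": maps a latent of size d_latent to an "image" of n_pixels values
    return torch.tanh(A @ w)


w = torch.randn(d_latent, requires_grad=True)
y = torch.randn(n_pixels)

# Explicit Jacobian J_w = dg/dw, shape [n_pixels, d_latent]
J = jacobian(g, w)
explicit = J.T @ y

# Vector-Jacobian product without building J: gradient of g(w) · y with respect to w
implicit, = torch.autograd.grad((g(w) * y).sum(), w)

print(torch.allclose(explicit, implicit, atol=1e-5))  # True
```

For a real generator, $\mathbf{J}_w$ would have one row per pixel, so a single backward pass through $g(w) \cdot y$ is far cheaper than building the Jacobian.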

labml.ai