StyleGAN 2
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/gan/stylegan/__init__.py)
This is a PyTorch implementation of the paper Analyzing and Improving the Image Quality of StyleGAN, which introduces StyleGAN 2. StyleGAN 2 is an improvement over StyleGAN from the paper A Style-Based Generator Architecture for Generative Adversarial Networks. StyleGAN, in turn, is based on Progressive GAN from the paper Progressive Growing of GANs for Improved Quality, Stability, and Variation. All three papers are by the same authors at NVIDIA AI.
Our implementation is minimalistic StyleGAN 2 model training code. Only single-GPU training is supported, to keep the implementation simple. We kept it under 500 lines of code, including the training loop.
🏃 Here's the training code: experiment.py.
These are 64×64 images generated after training for about 80K steps.
We'll first introduce the three papers at a high level.
Generative adversarial networks have two components: the generator and the discriminator. The generator network takes a random latent vector (z∈Z) and tries to generate a realistic image. The discriminator network tries to differentiate real images from generated images. When we train the two networks together, the generator starts generating images indistinguishable from real images.
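To make the adversarial setup concrete, here is a minimal sketch of one training step, assuming hypothetical `generator` and `discriminator` callables that return images and classification logits respectively; the actual training loop for this implementation lives in experiment.py.

```python
import torch
import torch.nn.functional as F


def gan_step(generator, discriminator, real_images: torch.Tensor, d_latent: int = 512):
    # Sample random latent vectors z
    z = torch.randn(real_images.shape[0], d_latent)
    fake_images = generator(z)
    # Discriminator loss: classify real images as 1 and generated images as 0
    d_real = discriminator(real_images)
    d_fake = discriminator(fake_images.detach())
    disc_loss = (F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real)) +
                 F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake)))
    # Generator loss: try to get generated images classified as real
    d_fake = discriminator(fake_images)
    gen_loss = F.binary_cross_entropy_with_logits(d_fake, torch.ones_like(d_fake))
    return disc_loss, gen_loss
```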
Progressive GAN generates high-resolution images (up to 1024×1024) by progressively increasing the image size. First, it trains a network that produces a 4×4 image, then an 8×8 image, then a 16×16 image, and so on up to the desired image resolution.
At each resolution, the generator network produces an image in latent space, which is converted into RGB with a 1×1 convolution. When we progress from a lower resolution to a higher resolution (say from 4×4 to 8×8), we scale the latent image by 2× and add a new block (two 3×3 convolution layers) and a new 1×1 layer to get RGB. The transition is done smoothly by adding a residual connection to the 2×-scaled 4×4 RGB image. The weight of this residual connection is slowly reduced to let the new block take over.
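Here is a minimal sketch of that smooth transition, assuming hypothetical tensors `rgb_low` (the RGB output of the existing lower-resolution network) and `rgb_new` (the RGB output of the newly added block), with `alpha` growing from 0 to 1 during the transition:

```python
import torch
import torch.nn.functional as F


def fade_in(rgb_low: torch.Tensor, rgb_new: torch.Tensor, alpha: float) -> torch.Tensor:
    # 2x scale the lower-resolution RGB image to match the new block's resolution
    upscaled = F.interpolate(rgb_low, scale_factor=2, mode='nearest')
    # Residual blend: the new block slowly takes over as alpha goes to 1
    return alpha * rgb_new + (1 - alpha) * upscaled
```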
The discriminator is a mirror image of the generator network. The progressive growth of the discriminator is done similarly.
2× and 0.5× denote up-sampling and down-sampling of the feature map resolution. 4×4, 8×8, ... denote the feature map resolution at each generator or discriminator block. Each discriminator and generator block consists of 2 convolution layers with leaky ReLU activations.
They use minibatch standard deviation to increase variation and equalized learning rate, both of which we discuss below in the implementation. They also use pixel-wise normalization, where the feature vector at each pixel is normalized. They apply this to all the convolution layer outputs (except RGB).
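Here is a minimal sketch of pixel-wise normalization as described in the Progressive GAN paper; this `PixelNorm` helper is illustrative and not part of the StyleGAN 2 code below:

```python
import torch
from torch import nn


class PixelNorm(nn.Module):
    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor):
        # x has shape [batch_size, channels, height, width];
        # normalize the feature vector at each pixel to unit average square
        return x / torch.sqrt((x ** 2).mean(dim=1, keepdim=True) + self.eps)
```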
StyleGAN improves the generator of Progressive GAN, keeping the discriminator architecture the same.
It maps the random latent vector (z∈Z) into a different latent space (w∈W), with an 8-layer neural network. This gives an intermediate latent space W where the factors of variations are more linear (disentangled).
Then w is transformed into two vectors (styles) per layer i, $y_i = (y_{s,i}, y_{b,i}) = f_{A_i}(w)$, and used for scaling and shifting (biasing) in each layer with the AdaIN operator (normalize and scale):

$$\text{AdaIN}(x_i, y_i) = y_{s,i} \frac{x_i - \mu(x_i)}{\sigma(x_i)} + y_{b,i}$$
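A minimal sketch of the AdaIN operation, assuming `x` is a feature map of shape `[batch_size, channels, height, width]` and `y_s`, `y_b` are per-channel style vectors of shape `[batch_size, channels]` (StyleGAN 2 replaces this with weight modulation, implemented below):

```python
import torch


def adain(x: torch.Tensor, y_s: torch.Tensor, y_b: torch.Tensor, eps: float = 1e-8):
    # Per-channel mean and standard deviation over the spatial dimensions
    mu = x.mean(dim=(2, 3), keepdim=True)
    sigma = x.std(dim=(2, 3), keepdim=True) + eps
    # Normalize, then scale by y_s and shift by y_b
    return y_s[:, :, None, None] * (x - mu) / sigma + y_b[:, :, None, None]
```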
To prevent the generator from assuming adjacent styles are correlated, they randomly use different styles for different blocks. That is, they sample two latent vectors (z1, z2) with corresponding (w1, w2), and use w1-based styles for some blocks and w2-based styles for the others, chosen at random.
Noise is made available to each block which helps the generator create more realistic images. Noise is scaled per channel by a learned weight.
All the up and down-sampling operations are accompanied by bilinear smoothing.
A denotes a linear layer. B denotes a broadcast and scaling operation (noise is a single channel). StyleGAN also uses progressive growing like Progressive GAN.
StyleGAN 2 changes both the generator and the discriminator of StyleGAN.
They remove the AdaIN operator and replace it with a weight modulation and demodulation step. This removes the droplet artifacts present in generated images, which are caused by the normalization in the AdaIN operator. The style vector per layer is calculated from $w_i \in \mathcal{W}$ as $s_i = f_{A_i}(w_i)$.
Then the convolution weights w are modulated as follows (w from here on refers to weights, not the intermediate latent space; we stick to the paper's notation):

$$w'_{i,j,k} = s_i \cdot w_{i,j,k}$$

Then they are demodulated by normalizing,

$$w''_{i,j,k} = \frac{w'_{i,j,k}}{\sqrt{\sum_{i,k} \left(w'_{i,j,k}\right)^2 + \epsilon}}$$

where i is the input channel, j is the output channel, and k is the kernel index.
Path length regularization encourages a fixed-size step in W to result in a non-zero, fixed-magnitude change in the generated image.
StyleGAN 2 uses residual connections (with down-sampling) in the discriminator and skip connections in the generator with up-sampling (the RGB outputs from each layer are added; there are no residual connections in the feature maps). They show with experiments that the contribution of low-resolution layers is higher at the beginning of training, and that high-resolution layers take over later.
```python
import math
from typing import Tuple, Optional, List

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn
```
This is an MLP with 8 linear layers. The mapping network maps the latent vector z∈Z to an intermediate latent space w∈W. The W space is disentangled from the image space, where the factors of variation become more linear.
```python
class MappingNetwork(nn.Module):
```

* `features` is the number of features in z and w
* `n_layers` is the number of layers in the mapping network

```python
    def __init__(self, features: int, n_layers: int):
        super().__init__()
```

Create the MLP

```python
        layers = []
        for i in range(n_layers):
```

Equalized learning-rate linear layers

```python
            layers.append(EqualizedLinear(features, features))
```

Leaky ReLU

```python
            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))

        self.net = nn.Sequential(*layers)
```

```python
    def forward(self, z: torch.Tensor):
```

Normalize z

```python
        z = F.normalize(z, dim=1)
```

Map z to w

```python
        return self.net(z)
```
A denotes a linear layer. B denotes a broadcast and scaling operation (noise is a single channel). toRGB also has a style modulation which is not shown in the diagram to keep it simple.
The generator starts with a learned constant. Then it has a series of blocks. The feature map resolution is doubled at each block. Each block outputs an RGB image, and they are scaled up and summed to get the final RGB image.
```python
class Generator(nn.Module):
```

* `log_resolution` is the log2 of image resolution
* `d_latent` is the dimensionality of w
* `n_features` is the number of features in the convolution layer at the highest resolution (final block)
* `max_features` is the maximum number of features in any generator block

```python
    def __init__(self, log_resolution: int, d_latent: int, n_features: int = 32, max_features: int = 512):
        super().__init__()
```

Calculate the number of features for each block; something like `[512, 512, 256, 128, 64, 32]`

```python
        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 2, -1, -1)]
```

Number of generator blocks

```python
        self.n_blocks = len(features)
```

Trainable 4×4 constant

```python
        self.initial_constant = nn.Parameter(torch.randn((1, features[0], 4, 4)))
```

First style block for the 4×4 resolution and the layer to get RGB

```python
        self.style_block = StyleBlock(d_latent, features[0], features[0])
        self.to_rgb = ToRGB(d_latent, features[0])
```

Generator blocks

```python
        blocks = [GeneratorBlock(d_latent, features[i - 1], features[i]) for i in range(1, self.n_blocks)]
        self.blocks = nn.ModuleList(blocks)
```

2× up-sampling layer. The feature space is up-sampled at each block.

```python
        self.up_sample = UpSample()
```
* `w` is w. In order to mix styles (use different w for different layers), we provide a separate w for each generator block. It has shape `[n_blocks, batch_size, d_latent]`.
* `input_noise` is the noise for each block. It's a list of pairs of noise tensors, because each block (except the initial one) has two noise inputs, one after each convolution layer (see the diagram).

```python
    def forward(self, w: torch.Tensor, input_noise: List[Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]]):
```
Get batch size

```python
        batch_size = w.shape[1]
```

Expand the learned constant to match the batch size

```python
        x = self.initial_constant.expand(batch_size, -1, -1, -1)
```

The first style block

```python
        x = self.style_block(x, w[0], input_noise[0][1])
```

Get the first RGB image

```python
        rgb = self.to_rgb(x, w[0])
```

Evaluate the rest of the blocks

```python
        for i in range(1, self.n_blocks):
```

Up-sample the feature map

```python
            x = self.up_sample(x)
```

Run it through the generator block

```python
            x, rgb_new = self.blocks[i - 1](x, w[i], input_noise[i])
```

Up-sample the RGB image and add the RGB from the block

```python
            rgb = self.up_sample(rgb) + rgb_new
```

Return the final RGB image

```python
        return rgb
```
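A hypothetical usage sketch for 32×32 images (`log_resolution = 5`), with the same w used for every block and freshly sampled noise per block:

```python
mapping_network = MappingNetwork(features=512, n_layers=8)
generator = Generator(log_resolution=5, d_latent=512)

z = torch.randn(4, 512)
# Same w for all blocks; shape [n_blocks, batch_size, d_latent]
w = mapping_network(z)[None, :, :].expand(generator.n_blocks, -1, -1)

# One pair of noise tensors per block; the resolution doubles at each block
noise, resolution = [], 4
for i in range(generator.n_blocks):
    n1 = None if i == 0 else torch.randn(4, 1, resolution, resolution)
    n2 = torch.randn(4, 1, resolution, resolution)
    noise.append((n1, n2))
    resolution *= 2

images = generator(w, noise)  # [4, 3, 32, 32]
```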
A denotes a linear layer. B denotes a broadcast and scaling operation (noise is a single channel). toRGB also has a style modulation which is not shown in the diagram to keep it simple.
The generator block consists of two style blocks (3×3 convolutions with style modulation) and an RGB output.
```python
class GeneratorBlock(nn.Module):
```

* `d_latent` is the dimensionality of w
* `in_features` is the number of features in the input feature map
* `out_features` is the number of features in the output feature map

```python
    def __init__(self, d_latent: int, in_features: int, out_features: int):
        super().__init__()
```

First style block changes the feature map size to `out_features`

```python
        self.style_block1 = StyleBlock(d_latent, in_features, out_features)
```

Second style block

```python
        self.style_block2 = StyleBlock(d_latent, out_features, out_features)
```

toRGB layer

```python
        self.to_rgb = ToRGB(d_latent, out_features)
```

* `x` is the input feature map of shape `[batch_size, in_features, height, width]`
* `w` is w with shape `[batch_size, d_latent]`
* `noise` is a tuple of two noise tensors of shape `[batch_size, 1, height, width]`

```python
    def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]):
```

First style block with the first noise tensor. The output is of shape `[batch_size, out_features, height, width]`

```python
        x = self.style_block1(x, w, noise[0])
```

Second style block with the second noise tensor. The output is of shape `[batch_size, out_features, height, width]`

```python
        x = self.style_block2(x, w, noise[1])
```

Get the RGB image

```python
        rgb = self.to_rgb(x, w)
```

Return the feature map and the RGB image

```python
        return x, rgb
```
A denotes a linear layer. B denotes a broadcast and scaling operation (noise is a single channel).
Style block has a weight modulation convolution layer.
```python
class StyleBlock(nn.Module):
```

* `d_latent` is the dimensionality of w
* `in_features` is the number of features in the input feature map
* `out_features` is the number of features in the output feature map

```python
    def __init__(self, d_latent: int, in_features: int, out_features: int):
        super().__init__()
```

Get the style vector from w (denoted by A in the diagram) with an equalized learning-rate linear layer

```python
        self.to_style = EqualizedLinear(d_latent, in_features, bias=1.0)
```

Weight modulated convolution layer

```python
        self.conv = Conv2dWeightModulate(in_features, out_features, kernel_size=3)
```

Noise scale

```python
        self.scale_noise = nn.Parameter(torch.zeros(1))
```

Bias

```python
        self.bias = nn.Parameter(torch.zeros(out_features))
```

Activation function

```python
        self.activation = nn.LeakyReLU(0.2, True)
```

* `x` is the input feature map of shape `[batch_size, in_features, height, width]`
* `w` is w with shape `[batch_size, d_latent]`
* `noise` is a tensor of shape `[batch_size, 1, height, width]`

```python
    def forward(self, x: torch.Tensor, w: torch.Tensor, noise: Optional[torch.Tensor]):
```

Get the style vector s

```python
        s = self.to_style(w)
```

Weight modulated convolution

```python
        x = self.conv(x, s)
```

Scale and add noise

```python
        if noise is not None:
            x = x + self.scale_noise[None, :, None, None] * noise
```

Add bias and evaluate the activation function

```python
        return self.activation(x + self.bias[None, :, None, None])
```
A denotes a linear layer.
Generates an RGB image from a feature map using 1×1 convolution.
```python
class ToRGB(nn.Module):
```

* `d_latent` is the dimensionality of w
* `features` is the number of features in the feature map

```python
    def __init__(self, d_latent: int, features: int):
        super().__init__()
```

Get the style vector from w (denoted by A in the diagram) with an equalized learning-rate linear layer

```python
        self.to_style = EqualizedLinear(d_latent, features, bias=1.0)
```

Weight modulated convolution layer without demodulation

```python
        self.conv = Conv2dWeightModulate(features, 3, kernel_size=1, demodulate=False)
```

Bias

```python
        self.bias = nn.Parameter(torch.zeros(3))
```

Activation function

```python
        self.activation = nn.LeakyReLU(0.2, True)
```

* `x` is the input feature map of shape `[batch_size, in_features, height, width]`
* `w` is w with shape `[batch_size, d_latent]`

```python
    def forward(self, x: torch.Tensor, w: torch.Tensor):
```

Get the style vector s

```python
        style = self.to_style(w)
```

Weight modulated convolution

```python
        x = self.conv(x, style)
```

Add bias and evaluate the activation function

```python
        return self.activation(x + self.bias[None, :, None, None])
```
This layer scales the convolution weights by the style vector and demodulates them by normalizing.
```python
class Conv2dWeightModulate(nn.Module):
```

* `in_features` is the number of features in the input feature map
* `out_features` is the number of features in the output feature map
* `kernel_size` is the size of the convolution kernel
* `demodulate` is a flag indicating whether to normalize weights by their standard deviation
* `eps` is the ϵ for normalizing

```python
    def __init__(self, in_features: int, out_features: int, kernel_size: int,
                 demodulate: bool = True, eps: float = 1e-8):
        super().__init__()
```

Number of output features

```python
        self.out_features = out_features
```

Whether to normalize weights

```python
        self.demodulate = demodulate
```

Padding size

```python
        self.padding = (kernel_size - 1) // 2
```

Weights parameter with equalized learning rate

```python
        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
```

ϵ

```python
        self.eps = eps
```

* `x` is the input feature map of shape `[batch_size, in_features, height, width]`
* `s` is a style-based scaling tensor of shape `[batch_size, in_features]`

```python
    def forward(self, x: torch.Tensor, s: torch.Tensor):
```

Get batch size, height and width

```python
        b, _, h, w = x.shape
```

Reshape the scales

```python
        s = s[:, None, :, None, None]
```

Get learning-rate equalized weights

```python
        weights = self.weight()[None, :, :, :, :]
```

$w'_{i,j,k} = s_i \cdot w_{i,j,k}$, where $i$ is the input channel, $j$ is the output channel, and $k$ is the kernel index.

The result has shape `[batch_size, out_features, in_features, kernel_size, kernel_size]`

```python
        weights = weights * s
```

Demodulate

```python
        if self.demodulate:
```

$\sigma_j = \sqrt{\sum_{i,k} \left(w'_{i,j,k}\right)^2 + \epsilon}$

```python
            sigma_inv = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
```

$w''_{i,j,k} = \dfrac{w'_{i,j,k}}{\sqrt{\sum_{i,k} \left(w'_{i,j,k}\right)^2 + \epsilon}}$

```python
            weights = weights * sigma_inv
```

Reshape x

```python
        x = x.reshape(1, -1, h, w)
```

Reshape weights

```python
        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.out_features, *ws)
```

Use grouped convolution to efficiently calculate the convolution with sample-wise kernels; i.e., we have different kernels (weights) for each sample in the batch

```python
        x = F.conv2d(x, weights, padding=self.padding, groups=b)
```

Reshape x to `[batch_size, out_features, height, width]` and return

```python
        return x.reshape(-1, self.out_features, h, w)
```
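A quick, hypothetical shape check of the modulated convolution:

```python
conv = Conv2dWeightModulate(in_features=64, out_features=128, kernel_size=3)
x = torch.randn(8, 64, 16, 16)  # feature map
s = torch.randn(8, 64)          # per-sample style scales for the input channels
assert conv(x, s).shape == (8, 128, 16, 16)
```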
The discriminator first transforms the image to a feature map of the same resolution and then runs it through a series of blocks with residual connections. The resolution is down-sampled by 2× at each block while the number of features is doubled.
```python
class Discriminator(nn.Module):
```

* `log_resolution` is the log2 of image resolution
* `n_features` is the number of features in the convolution layer at the highest resolution (first block)
* `max_features` is the maximum number of features in any discriminator block

```python
    def __init__(self, log_resolution: int, n_features: int = 64, max_features: int = 512):
        super().__init__()
```

Layer to convert the RGB image to a feature map with `n_features` number of features

```python
        self.from_rgb = nn.Sequential(
            EqualizedConv2d(3, n_features, 1),
            nn.LeakyReLU(0.2, True),
        )
```

Calculate the number of features for each block; something like `[64, 128, 256, 512, 512, 512]`

```python
        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 1)]
```

Number of discriminator blocks

```python
        n_blocks = len(features) - 1
```

Discriminator blocks

```python
        blocks = [DiscriminatorBlock(features[i], features[i + 1]) for i in range(n_blocks)]
        self.blocks = nn.Sequential(*blocks)
```

Mini-batch standard deviation layer

```python
        self.std_dev = MiniBatchStdDev()
```

Number of features after appending the standard-deviation map

```python
        final_features = features[-1] + 1
```

Final 3×3 convolution layer

```python
        self.conv = EqualizedConv2d(final_features, final_features, 3)
```

Final linear layer to get the classification score

```python
        self.final = EqualizedLinear(2 * 2 * final_features, 1)
```

* `x` is the input image of shape `[batch_size, 3, height, width]`

```python
    def forward(self, x: torch.Tensor):
```

Try to normalize the image (this is totally optional, but sped up early training a little)

```python
        x = x - 0.5
```

Convert from RGB

```python
        x = self.from_rgb(x)
```

Run through the discriminator blocks

```python
        x = self.blocks(x)
```

Calculate and append the mini-batch standard deviation

```python
        x = self.std_dev(x)
```

3×3 convolution

```python
        x = self.conv(x)
```

Flatten

```python
        x = x.reshape(x.shape[0], -1)
```

Return the classification score

```python
        return self.final(x)
```
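A hypothetical usage sketch for 32×32 images; note that the batch size must be divisible by the `MiniBatchStdDev` group size (4 by default):

```python
discriminator = Discriminator(log_resolution=5)
images = torch.rand(8, 3, 32, 32)  # batch size divisible by the group size of 4
scores = discriminator(images)     # [8, 1] classification scores (logits)
```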
Discriminator block consists of two 3×3 convolutions with a residual connection.
```python
class DiscriminatorBlock(nn.Module):
```

* `in_features` is the number of features in the input feature map
* `out_features` is the number of features in the output feature map

```python
    def __init__(self, in_features, out_features):
        super().__init__()
```

Down-sampling and 1×1 convolution layer for the residual connection

```python
        self.residual = nn.Sequential(DownSample(),
                                      EqualizedConv2d(in_features, out_features, kernel_size=1))
```

Two 3×3 convolutions

```python
        self.block = nn.Sequential(
            EqualizedConv2d(in_features, in_features, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            EqualizedConv2d(in_features, out_features, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
        )
```

Down-sampling layer

```python
        self.down_sample = DownSample()
```

Scaling factor $\frac{1}{\sqrt{2}}$ after adding the residual

```python
        self.scale = 1 / math.sqrt(2)
```

```python
    def forward(self, x):
```

Get the residual connection

```python
        residual = self.residual(x)
```

Convolutions

```python
        x = self.block(x)
```

Down-sample

```python
        x = self.down_sample(x)
```

Add the residual and scale

```python
        return (x + residual) * self.scale
```
Mini-batch standard deviation calculates the standard deviation across a mini-batch (or subgroups within the mini-batch) for each feature in the feature map. Then it takes the mean of all the standard deviations and appends it to the feature map as one extra feature.
```python
class MiniBatchStdDev(nn.Module):
```

* `group_size` is the number of samples to calculate the standard deviation across

```python
    def __init__(self, group_size: int = 4):
        super().__init__()
        self.group_size = group_size
```

* `x` is the feature map

```python
    def forward(self, x: torch.Tensor):
```

Check that the batch size is divisible by the group size

```python
        assert x.shape[0] % self.group_size == 0
```

Split the samples into groups of `group_size`; we flatten the feature map to a single dimension since we want to calculate the standard deviation for each feature

```python
        grouped = x.view(self.group_size, -1)
```

Calculate the standard deviation for each feature among `group_size` samples:

$$\mu_i = \frac{1}{N} \sum_g x_{g,i} \qquad \sigma_i = \sqrt{\frac{1}{N} \sum_g (x_{g,i} - \mu_i)^2 + \epsilon}$$

```python
        std = torch.sqrt(grouped.var(dim=0) + 1e-8)
```

Get the mean standard deviation

```python
        std = std.mean().view(1, 1, 1, 1)
```

Expand the standard deviation to append to the feature map

```python
        b, _, h, w = x.shape
        std = std.expand(b, -1, h, w)
```

Append (concatenate) the standard deviations to the feature map

```python
        return torch.cat([x, std], dim=1)
```
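A quick, hypothetical check showing the single extra feature channel:

```python
std_dev = MiniBatchStdDev(group_size=4)
x = torch.randn(8, 512, 4, 4)
assert std_dev(x).shape == (8, 513, 4, 4)  # one std-dev channel appended
```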
The down-sample operation smoothens each feature channel and scales the resolution down by 2× using bilinear interpolation. This is based on the paper Making Convolutional Networks Shift-Invariant Again.
```python
class DownSample(nn.Module):
    def __init__(self):
        super().__init__()
```

Smoothing layer

```python
        self.smooth = Smooth()
```

```python
    def forward(self, x: torch.Tensor):
```

Smoothing or blurring

```python
        x = self.smooth(x)
```

Scale down

```python
        return F.interpolate(x, (x.shape[2] // 2, x.shape[3] // 2), mode='bilinear', align_corners=False)
```
The up-sample operation scales the image up by 2× and smoothens each feature channel. This is based on the paper Making Convolutional Networks Shift-Invariant Again.
```python
class UpSample(nn.Module):
    def __init__(self):
        super().__init__()
```

Up-sampling layer

```python
        self.up_sample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
```

Smoothing layer

```python
        self.smooth = Smooth()
```

```python
    def forward(self, x: torch.Tensor):
```

Up-sample and smoothen

```python
        return self.smooth(self.up_sample(x))
```
This layer blurs each channel.
```python
class Smooth(nn.Module):
    def __init__(self):
        super().__init__()
```

Blurring kernel

```python
        kernel = [[1, 2, 1],
                  [2, 4, 2],
                  [1, 2, 1]]
```

Convert the kernel to a PyTorch tensor

```python
        kernel = torch.tensor([[kernel]], dtype=torch.float)
```

Normalize the kernel

```python
        kernel /= kernel.sum()
```

Save the kernel as a fixed parameter (no gradient updates)

```python
        self.kernel = nn.Parameter(kernel, requires_grad=False)
```

Padding layer

```python
        self.pad = nn.ReplicationPad2d(1)
```

```python
    def forward(self, x: torch.Tensor):
```

Get the shape of the input feature map

```python
        b, c, h, w = x.shape
```

Reshape for smoothening

```python
        x = x.view(-1, 1, h, w)
```

Add padding

```python
        x = self.pad(x)
```

Smoothen (blur) with the kernel

```python
        x = F.conv2d(x, self.kernel)
```

Reshape and return

```python
        return x.view(b, c, h, w)
```
This uses learning-rate equalized weights for a linear layer.
```python
class EqualizedLinear(nn.Module):
```

* `in_features` is the number of features in the input feature map
* `out_features` is the number of features in the output feature map
* `bias` is the bias initialization constant

```python
    def __init__(self, in_features: int, out_features: int, bias: float = 0.):
        super().__init__()
```

Learning-rate equalized weights

```python
        self.weight = EqualizedWeight([out_features, in_features])
```

Bias

```python
        self.bias = nn.Parameter(torch.ones(out_features) * bias)
```

```python
    def forward(self, x: torch.Tensor):
```

Linear transformation

```python
        return F.linear(x, self.weight(), bias=self.bias)
```
This uses learning-rate equalized weights for a convolution layer.
```python
class EqualizedConv2d(nn.Module):
```

* `in_features` is the number of features in the input feature map
* `out_features` is the number of features in the output feature map
* `kernel_size` is the size of the convolution kernel
* `padding` is the padding to be added on both sides of each size dimension

```python
    def __init__(self, in_features: int, out_features: int,
                 kernel_size: int, padding: int = 0):
        super().__init__()
```

Padding size

```python
        self.padding = padding
```

Learning-rate equalized weights

```python
        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
```

Bias

```python
        self.bias = nn.Parameter(torch.ones(out_features))
```

```python
    def forward(self, x: torch.Tensor):
```

Convolution

```python
        return F.conv2d(x, self.weight(), bias=self.bias, padding=self.padding)
```
This is based on equalized learning rate introduced in the Progressive GAN paper. Instead of initializing weights at $\mathcal{N}(0, c)$, they initialize weights to $\mathcal{N}(0, 1)$ and then multiply them by $c$ when using them:

$$w_i = c \hat{w}_i$$
The gradients on the stored parameters $\hat{w}$ get multiplied by $c$, but this doesn't have an effect since optimizers such as Adam normalize them by a running mean of the squared gradients.
The optimizer updates on $\hat{w}$ are proportional to the learning rate λ. But the effective weights $w$ get updated proportionally to $c\lambda$. Without equalized learning rate, the effective weights would get updated proportionally to just λ.
So we are effectively scaling the learning rate by c for these weight parameters.
```python
class EqualizedWeight(nn.Module):
```

* `shape` is the shape of the weight parameter

```python
    def __init__(self, shape: List[int]):
        super().__init__()
```

He initialization constant

```python
        self.c = 1 / math.sqrt(np.prod(shape[1:]))
```

Initialize the weights with N(0, 1)

```python
        self.weight = nn.Parameter(torch.randn(shape))
```

```python
    def forward(self):
```

Multiply the weights by c and return

```python
        return self.weight * self.c
```
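A quick, hypothetical check of the scaling: the stored parameter stays at N(0, 1) while the effective weights come out He-initialized:

```python
weight = EqualizedWeight([128, 64, 3, 3])  # fan-in is 64 * 3 * 3 = 576
print(weight.weight.std())  # ~1.0, the stored parameter
print(weight().std())       # ~1 / sqrt(576) ~ 0.042, the effective weights
```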
This is the R1 regularization penalty from the paper Which Training Methods for GANs do actually Converge?.
$$R_1(\psi) = \frac{\gamma}{2} \mathbb{E}_{p_\mathcal{D}(x)} \left[ \Vert \nabla_x D_\psi(x) \Vert^2 \right]$$
That is, we try to reduce the L2 norm of the gradients of the discriminator with respect to images, for real images ($p_\mathcal{D}$).
```python
class GradientPenalty(nn.Module):
```

* `x` is $x \sim p_\mathcal{D}$
* `d` is $D(x)$

```python
    def forward(self, x: torch.Tensor, d: torch.Tensor):
```

Get batch size

```python
        batch_size = x.shape[0]
```

Calculate the gradients of $D(x)$ with respect to $x$. `grad_outputs` is set to ones since we want the gradients of $D(x)$, and we need to create and retain the graph since we have to compute gradients with respect to the weights on this loss.

```python
        gradients, *_ = torch.autograd.grad(outputs=d,
                                            inputs=x,
                                            grad_outputs=d.new_ones(d.shape),
                                            create_graph=True)
```

Reshape gradients to calculate the norm

```python
        gradients = gradients.reshape(batch_size, -1)
```

Calculate the norm $\Vert \nabla_x D_\psi(x) \Vert_2$

```python
        norm = gradients.norm(2, dim=-1)
```

Return the loss $\Vert \nabla_x D_\psi(x) \Vert_2^2$

```python
        return torch.mean(norm ** 2)
```
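A hypothetical usage sketch inside the discriminator training step, reusing the discriminator from the sketch above; the images must require gradients:

```python
gradient_penalty = GradientPenalty()
gamma = 10.  # illustrative regularization weight
real_images = torch.rand(8, 3, 32, 32, requires_grad=True)
d_real = discriminator(real_images)
r1_loss = 0.5 * gamma * gradient_penalty(real_images, d_real)
```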
This regularization encourages a fixed-size step in w to result in a fixed-magnitude change in the image.
$$\mathbb{E}_{w \sim f(z), y \sim \mathcal{N}(0, I)} \left( \Vert J_w^\top y \Vert_2 - a \right)^2$$

where $J_w$ is the Jacobian $J_w = \frac{\partial g}{\partial w}$, $w$ is sampled from the mapping network ($w \in \mathcal{W}$), and $y$ are images with noise $\mathcal{N}(0, I)$.

$a$ is the exponential moving average of $\Vert J_w^\top y \Vert_2$ as the training progresses.

$J_w^\top y$ is calculated without explicitly computing the Jacobian, using $J_w^\top y = \nabla_w \big(g(w) \cdot y\big)$.
```python
class PathLengthPenalty(nn.Module):
```

* `beta` is the constant β used to calculate the exponential moving average a

```python
    def __init__(self, beta: float):
        super().__init__()
```

β

```python
        self.beta = beta
```

Number of steps calculated, N

```python
        self.steps = nn.Parameter(torch.tensor(0.), requires_grad=False)
```

Exponential sum of $\Vert J_w^\top y \Vert_2$:

$$\sum_{i=1}^N \beta^{N-i} \big[ \Vert J_w^\top y \Vert_2 \big]_i$$

where $\big[ \Vert J_w^\top y \Vert_2 \big]_i$ is the value at the $i$-th training step

```python
        self.exp_sum_a = nn.Parameter(torch.tensor(0.), requires_grad=False)
```
* `w` is the batch of w of shape `[n_blocks, batch_size, d_latent]` (the same w that was passed to the generator)
* `x` are the generated images of shape `[batch_size, 3, height, width]`

```python
    def forward(self, w: torch.Tensor, x: torch.Tensor):
```
Get the device

```python
        device = x.device
```

Get the number of pixels

```python
        image_size = x.shape[2] * x.shape[3]
```

Calculate $y \sim \mathcal{N}(0, I)$

```python
        y = torch.randn(x.shape, device=device)
```
Calculate $\big(g(w) \cdot y\big)$ and normalize by the square root of the image size. This scaling is not mentioned in the paper but was present in their implementation.

```python
        output = (x * y).sum() / math.sqrt(image_size)
```
Calculate gradients to get $J_w^\top y$

```python
        gradients, *_ = torch.autograd.grad(outputs=output,
                                            inputs=w,
                                            grad_outputs=torch.ones(output.shape, device=device),
                                            create_graph=True)
```

Calculate the L2-norm of $J_w^\top y$

```python
        norm = (gradients ** 2).sum(dim=2).mean(dim=1).sqrt()
```

Regularize after the first step

```python
        if self.steps > 0:
```

Calculate $a$:

$$a = \frac{1}{1 - \beta^N} \sum_{i=1}^N \beta^{N-i} \big[ \Vert J_w^\top y \Vert_2 \big]_i$$

```python
            a = self.exp_sum_a / (1 - self.beta ** self.steps)
```

Calculate the penalty $\mathbb{E}_{w \sim f(z), y \sim \mathcal{N}(0, I)} \big( \Vert J_w^\top y \Vert_2 - a \big)^2$

```python
            loss = torch.mean((norm - a) ** 2)
        else:
```

Return a dummy loss if we can't calculate a

```python
            loss = norm.new_tensor(0)
```

Calculate the mean of $\Vert J_w^\top y \Vert_2$

```python
        mean = norm.mean().detach()
```

Update the exponential sum

```python
        self.exp_sum_a.mul_(self.beta).add_(mean, alpha=1 - self.beta)
```

Increment N

```python
        self.steps.add_(1.)
```

Return the penalty

```python
        return loss
```
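A hypothetical usage sketch inside the generator training step, reusing the generator, w and noise from the generator sketch above; w must require gradients, and in practice this penalty is applied only every few steps:

```python
path_length_penalty = PathLengthPenalty(beta=0.99)
w = w.detach().requires_grad_(True)  # gradients w.r.t. w are needed
images = generator(w, noise)
plp_loss = path_length_penalty(w, images)  # add this to the generator loss
```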