Gradient Penalty for Wasserstein GAN (WGAN-GP)

This is an implementation of Improved Training of Wasserstein GANs.

WGAN suggests clipping weights to enforce a Lipschitz constraint on the discriminator network (the critic). This and other weight constraints, such as L2 norm clipping, weight normalization, and L1 or L2 weight decay, have problems:

  1. Limiting the capacity of the discriminator
  2. Exploding and vanishing gradients (without Batch Normalization)
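For contrast, the weight clipping used by WGAN can be sketched in a few lines of PyTorch. This is a minimal illustration, not code from this repository: the helper `clip_weights`, the toy critic and the clip value `c = 0.01` are assumptions.

```python
import torch
from torch import nn


def clip_weights(critic: nn.Module, c: float = 0.01) -> None:
    """Clamp every critic parameter into [-c, c] in place (WGAN-style clipping)."""
    with torch.no_grad():
        for p in critic.parameters():
            p.clamp_(-c, c)


# A small hypothetical critic, used only for illustration
critic = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
clip_weights(critic, c=0.01)

max_abs = max(p.abs().max().item() for p in critic.parameters())
print(max_abs <= 0.01)  # True
```

Every parameter is forced into $[-c, c]$ after each update regardless of what the critic needs, which is why a hard box like this can restrict the functions the critic is able to represent.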

The paper Improved Training of Wasserstein GANs proposes a better way to enforce the Lipschitz constraint: a gradient penalty.

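$$\mathcal{L}_{GP} = \lambda \mathop{\mathbb{E}}_{\hat{x} \sim \mathbb{P}_{\hat{x}}} \Big[ \big(\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2 - 1\big)^2 \Big]$$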

where λ is the penalty weight and

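$$
\begin{aligned}
x &\sim \mathbb{P}_r \\
z &\sim p(z) \\
\epsilon &\sim U[0,1] \\
\tilde{x} &\leftarrow G_\theta(z) \\
\hat{x} &\leftarrow \epsilon x + (1 - \epsilon) \tilde{x}
\end{aligned}
$$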

That is, we try to keep the gradient norm $\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2$ close to 1.
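A minimal sketch of $\mathcal{L}_{GP}$ for a general batch $\hat{x}$, including the penalty weight $\lambda$. The function name `wgan_gp_penalty` and the default $\lambda = 10$ are assumptions for illustration. Unlike the module below, it calls the critic itself and differentiates `d_hat.sum()`, which gives the same per-sample gradients as long as the critic does not mix samples across the batch.

```python
import torch
from torch import nn


def wgan_gp_penalty(critic: nn.Module, x_hat: torch.Tensor,
                    penalty_weight: float = 10.0) -> torch.Tensor:
    """Estimate lambda * E[(||grad_{x_hat} D(x_hat)||_2 - 1)^2] over a batch."""
    # The penalty only trains the critic, so start from a fresh leaf tensor
    x_hat = x_hat.detach().requires_grad_(True)
    d_hat = critic(x_hat)
    # Summing is fine because each D(x_hat_i) depends only on x_hat_i
    (grad,) = torch.autograd.grad(d_hat.sum(), x_hat, create_graph=True)
    # Per-sample L2 norm over all non-batch dimensions
    grad_norm = grad.reshape(grad.shape[0], -1).norm(2, dim=-1)
    return penalty_weight * ((grad_norm - 1) ** 2).mean()


# Check on a linear critic D(x) = w^T x with ||w||_2 = 5
critic = nn.Linear(3, 1, bias=False)
with torch.no_grad():
    critic.weight.copy_(torch.tensor([[3.0, 4.0, 0.0]]))

x_hat = torch.randn(6, 3)
print(wgan_gp_penalty(critic, x_hat).item())  # 160.0
```

For a linear critic $D(x) = w^\top x$ the gradient is $w$ at every point, so the penalty is $\lambda(\Vert w \Vert_2 - 1)^2$. With $\Vert w \Vert_2 = 5$ and $\lambda = 10$ that is $10 \cdot 16 = 160$, and it would be zero for any unit-norm $w$.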

In this implementation we set $\epsilon = 1$, so the penalty is evaluated at the real samples ($\hat{x} = x$).
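The sampling of $\hat{x}$ can be sketched as follows, assuming real and generated batches of the same shape. `interpolate` is a hypothetical helper: it draws one $\epsilon$ per sample and broadcasts it over the remaining dimensions. Passing $\epsilon = 1$ reproduces the choice made in this implementation, where $\hat{x}$ is simply the real batch.

```python
from typing import Optional

import torch


def interpolate(x_real: torch.Tensor, x_fake: torch.Tensor,
                eps: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Return x_hat = eps * x_real + (1 - eps) * x_fake.

    If eps is not given, draw one eps ~ U[0, 1] per sample, shaped so that it
    broadcasts over every non-batch dimension.
    """
    if eps is None:
        shape = (x_real.shape[0],) + (1,) * (x_real.dim() - 1)
        eps = torch.rand(shape, device=x_real.device, dtype=x_real.dtype)
    return eps * x_real + (1 - eps) * x_fake


x_real = torch.randn(4, 3, 8, 8)  # stand-in for a batch of real images
x_fake = torch.randn(4, 3, 8, 8)  # stand-in for G(z)
x_hat = interpolate(x_real, x_fake)

# With eps = 1, as in this implementation, x_hat is just the real batch
x_hat_eps1 = interpolate(x_real, x_fake, eps=torch.ones(4, 1, 1, 1))
print(torch.equal(x_hat_eps1, x_real))  # True
```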

Here is the code for an experiment that uses gradient penalty.

```python
import torch
import torch.autograd

from torch import nn
```

Gradient Penalty

```python
class GradientPenalty(nn.Module):
```

  • `x` is $x \sim \mathbb{P}_r$
  • `f` is $D(x)$

$\hat{x} \leftarrow x$ since we set $\epsilon = 1$ for this implementation.

```python
    def forward(self, x: torch.Tensor, f: torch.Tensor):
```

Get batch size

```python
        batch_size = x.shape[0]
```

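Calculate the gradients of $D(x)$ with respect to $x$. Since `f` holds one critic value per sample rather than a single scalar, `grad_outputs` is set to ones; this yields the gradient of $\sum_i D(x_i)$, whose $i$-th row is the gradient of $D(x_i)$ as long as the critic does not mix samples. We set `create_graph=True` (which also retains the graph by default) because we have to compute gradients of this loss with respect to the critic's weights.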

```python
        gradients, *_ = torch.autograd.grad(outputs=f,
                                            inputs=x,
                                            grad_outputs=f.new_ones(f.shape),
                                            create_graph=True)
```
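To see why ones work as `grad_outputs`, the sketch below compares the batched call with one gradient computation per sample. The small `critic` is hypothetical, and the check relies on it containing no layers that mix samples; batch normalization in training mode would break the equality.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical critic with no cross-sample layers (no batch normalization)
critic = nn.Sequential(nn.Linear(5, 16), nn.Tanh(), nn.Linear(16, 1))

x = torch.randn(3, 5, requires_grad=True)
f = critic(x)  # shape [3, 1]: one critic value per sample

# Batched call, mirroring GradientPenalty.forward
(batched,) = torch.autograd.grad(outputs=f, inputs=x,
                                 grad_outputs=f.new_ones(f.shape),
                                 create_graph=True)

# Reference: a separate gradient computation for each sample
per_sample = torch.stack([
    torch.autograd.grad(critic(x[i:i + 1]).sum(), x)[0][i]
    for i in range(x.shape[0])
])

print(torch.allclose(batched, per_sample, atol=1e-6))  # True
```

The batched result also has `requires_grad` set, because `create_graph=True` keeps it connected to the critic's parameters so a loss built from it can be backpropagated.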

Reshape gradients to calculate the norm

```python
        gradients = gradients.reshape(batch_size, -1)
```

Calculate the norm $\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2$

```python
        norm = gradients.norm(2, dim=-1)
```
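For image-shaped inputs the gradients have shape `[batch, channels, height, width]`, and flattening each sample first makes the result the full per-sample $L_2$ norm. The sketch below, on random stand-in gradients, checks this against an explicit sum of squares and against `torch.linalg.vector_norm` (the PyTorch documentation suggests the `torch.linalg` functions in place of `torch.norm`).

```python
import torch

torch.manual_seed(0)
gradients = torch.randn(4, 3, 8, 8)  # stand-in for gradients w.r.t. a batch of images
batch_size = gradients.shape[0]

# As in GradientPenalty: flatten each sample, then take the L2 norm of each row
flat = gradients.reshape(batch_size, -1)
norm = flat.norm(2, dim=-1)

# The same quantity written out: sqrt of the sum of squares over non-batch dims
reference = gradients.pow(2).sum(dim=(1, 2, 3)).sqrt()

# The torch.linalg counterpart
alt = torch.linalg.vector_norm(flat, ord=2, dim=-1)

print(norm.shape)                       # torch.Size([4])
print(torch.allclose(norm, reference))  # True
print(torch.allclose(norm, alt))        # True
```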

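Return the loss $\big(\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2 - 1\big)^2$, averaged over the batch. The penalty weight $\lambda$ is not applied here.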

```python
        return torch.mean((norm - 1) ** 2)
```
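A minimal sketch of how the module might be used in one critic update. The class is repeated so the snippet runs on its own; the critic architecture, the random data, the Adam settings, the WGAN critic loss and $\lambda = 10$ are assumptions, and `x_fake` stands in for detached generator output. Because the module returns the unweighted penalty, $\lambda$ is applied by the caller, and `x_real` must have `requires_grad` set before the critic is evaluated.

```python
import torch
import torch.autograd
from torch import nn


class GradientPenalty(nn.Module):
    """Same computation as above: mean of (||grad_x D(x)||_2 - 1)^2."""

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        gradients, *_ = torch.autograd.grad(outputs=f, inputs=x,
                                            grad_outputs=f.new_ones(f.shape),
                                            create_graph=True)
        gradients = gradients.reshape(batch_size, -1)
        norm = gradients.norm(2, dim=-1)
        return torch.mean((norm - 1) ** 2)


torch.manual_seed(0)

# Hypothetical critic, optimizer and data; sizes and settings are illustrative
critic = nn.Sequential(nn.Flatten(), nn.Linear(16, 32), nn.LeakyReLU(0.2), nn.Linear(32, 1))
optimizer = torch.optim.Adam(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
gradient_penalty = GradientPenalty()
penalty_weight = 10.0  # lambda, an assumed value

x_real = torch.randn(8, 1, 4, 4)
x_fake = torch.randn(8, 1, 4, 4)  # would be detached generator output

# x must require grad so that autograd.grad can differentiate D(x) w.r.t. it
x_real.requires_grad_()
f_real = critic(x_real)
f_fake = critic(x_fake)

# Wasserstein critic loss plus the weighted gradient penalty
critic_loss = f_fake.mean() - f_real.mean()
gp = gradient_penalty(x_real, f_real)
loss = critic_loss + penalty_weight * gp

optimizer.zero_grad()
loss.backward()
optimizer.step()
```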
