This is an implementation of Improved Training of Wasserstein GANs.
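WGAN suggests clipping weights to enforce the Lipschitz constraint on the discriminator network (critic). This and other weight constraints, like L2 norm clipping, weight normalization, and L1 or L2 weight decay, have problems: they limit the capacity of the critic, and they can cause exploding or vanishing gradients unless the constraint (for clipping, the threshold c) is carefully tuned.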
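For contrast, the weight clipping that WGAN relies on can be sketched in a few lines: after each critic update, every parameter is clamped into $[-c, c]$. This is a minimal sketch. The toy `critic` network and the helper name `clip_weights` are hypothetical, and `c = 0.01` is the default threshold from the original WGAN paper.

```python
import torch
from torch import nn

# Hypothetical toy critic, used only to illustrate clipping
critic = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))

# Clipping threshold (0.01 is the default in the WGAN paper)
c = 0.01


def clip_weights(model: nn.Module, c: float) -> None:
    """Clamp every parameter of `model` in place to the box [-c, c]."""
    with torch.no_grad():
        for p in model.parameters():
            p.clamp_(-c, c)


# In WGAN this would run after every critic optimizer step
clip_weights(critic, c)
```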
The paper Improved Training of Wasserstein GANs proposes a better way to enforce the Lipschitz constraint: a gradient penalty on the critic.
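$$\mathcal{L}_{GP} = \lambda \mathop{\mathbb{E}}_{\hat{x} \sim \mathbb{P}_{\hat{x}}} \Big[ \big(\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2 - 1\big)^2 \Big]$$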
where λ is the penalty weight and
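$$\begin{aligned}
x &\sim \mathbb{P}_r \\
z &\sim p(z) \\
\epsilon &\sim U[0,1] \\
\tilde{x} &\leftarrow G_\theta(z) \\
\hat{x} &\leftarrow \epsilon x + (1 - \epsilon) \tilde{x}
\end{aligned}$$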
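That is, we try to keep the gradient norm $\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2$ close to 1. This acts as a soft Lipschitz constraint because a differentiable function is 1-Lipschitz if and only if the norm of its gradient is at most 1 everywhere.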
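A small worked case may help. For a hypothetical linear critic $D(x) = w^\top x + b$, the gradient is $w$ at every point, so the expectation in $\mathcal{L}_{GP}$ drops out:

$$\nabla_{\hat{x}} D(\hat{x}) = w \quad\Longrightarrow\quad \mathcal{L}_{GP} = \lambda \big(\Vert w \Vert_2 - 1\big)^2$$

The penalty is zero exactly when $\Vert w \Vert_2 = 1$. By Cauchy-Schwarz, $|D(x) - D(y)| = |w^\top (x - y)| \le \Vert w \Vert_2 \Vert x - y \Vert_2$, with equality when $x - y$ is parallel to $w$. So $\Vert w \Vert_2$ is this critic's Lipschitz constant, and the penalty pulls it toward 1 from either side.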
In this implementation we set $\epsilon = 1$, so $\hat{x} = x$ and the penalty is evaluated on real samples only.
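For comparison, here is a minimal sketch of the penalty in the paper's general form, where each $\hat{x}$ is a random interpolation between a real and a generated sample. This is not part of this implementation, which fixes $\epsilon = 1$. The function name `interpolated_gradient_penalty` is hypothetical, and so are two other choices: drawing one $\epsilon$ per sample and broadcasting it over the remaining dimensions, and detaching both inputs.

```python
import torch
from torch import nn


def interpolated_gradient_penalty(critic: nn.Module,
                                  real: torch.Tensor,
                                  fake: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean of (||grad D(x_hat)||_2 - 1)^2,
    with x_hat = eps * x + (1 - eps) * x_tilde and eps ~ U[0, 1]."""
    batch_size = real.shape[0]
    # One epsilon per sample, shaped to broadcast over the non-batch dimensions
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    # Interpolate between detached real and fake samples; x_hat becomes a fresh leaf
    x_hat = (eps * real.detach() + (1 - eps) * fake.detach()).requires_grad_(True)
    d_hat = critic(x_hat)
    # Per-sample input gradients; keep the graph so the penalty can be backpropagated
    gradients, = torch.autograd.grad(outputs=d_hat,
                                     inputs=x_hat,
                                     grad_outputs=torch.ones_like(d_hat),
                                     create_graph=True)
    norm = gradients.reshape(batch_size, -1).norm(2, dim=-1)
    return torch.mean((norm - 1) ** 2)
```

For a linear critic with weight $w = (3, 4)$ this returns $(5 - 1)^2 = 16$ for any batch, matching the worked linear case.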
Here is the code for an experiment that uses the gradient penalty.
import torch
import torch.autograd

from torch import nn
class GradientPenalty(nn.Module):
x holds real samples $x \sim \mathbb{P}_r$, f is the critic output $D(x)$, and $\hat{x} \leftarrow x$ since we set $\epsilon = 1$ for this implementation.
    def forward(self, x: torch.Tensor, f: torch.Tensor):
Get batch size
        batch_size = x.shape[0]
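Calculate the gradients of $D(x)$ with respect to x. grad_outputs is set to ones because f holds one value per sample. This computes the gradient of the sum of $D(x)$, which equals each sample's own gradient as long as samples do not interact inside the critic. We set create_graph=True, which also retains the graph, because this loss is itself differentiated with respect to the critic's weights.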
        gradients, *_ = torch.autograd.grad(outputs=f,
                                            inputs=x,
                                            grad_outputs=f.new_ones(f.shape),
                                            create_graph=True)
Reshape gradients to calculate the norm
        gradients = gradients.reshape(batch_size, -1)
Calculate the norm $\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2$
        norm = gradients.norm(2, dim=-1)
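Return the loss $\big(\Vert \nabla_{\hat{x}} D(\hat{x}) \Vert_2 - 1\big)^2$, averaged over the batch as an estimate of the expectation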
        return torch.mean((norm - 1) ** 2)
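A minimal usage sketch on random stand-in data. Because this implementation takes $\hat{x} = x$, the real batch must have `requires_grad` enabled before it goes through the critic. Otherwise `torch.autograd.grad` has no path from `f` back to `x` and raises an error. The class is repeated so the snippet runs on its own. Everything around it is an assumption, not taken from the experiment code: the toy critic, the penalty weight `lambda_gp = 10.0` (the value suggested in the paper), and the critic loss $D(\tilde{x}) - D(x)$ averaged over the batch.

```python
import torch
from torch import nn


class GradientPenalty(nn.Module):
    """Same computation as the module above, repeated so this snippet is self-contained."""

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        gradients, *_ = torch.autograd.grad(outputs=f,
                                            inputs=x,
                                            grad_outputs=f.new_ones(f.shape),
                                            create_graph=True)
        norm = gradients.reshape(batch_size, -1).norm(2, dim=-1)
        return torch.mean((norm - 1) ** 2)


torch.manual_seed(0)
# Hypothetical toy critic and stand-in data
critic = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 1))
real = torch.randn(8, 2).requires_grad_(True)  # x must require grad for the penalty
fake = torch.randn(8, 2)                        # stands in for a detached G(z)
lambda_gp = 10.0                                # penalty weight suggested in the paper

f_real = critic(real)
f_fake = critic(fake)
gp = GradientPenalty()(real, f_real)
# WGAN critic loss plus the weighted gradient penalty
critic_loss = f_fake.mean() - f_real.mean() + lambda_gp * gp
# Backpropagates into the critic weights, including through the penalty term
critic_loss.backward()
```

Since the penalty here only looks at real samples, it constrains the critic's slope around the data rather than along lines between real and generated samples as in the paper's interpolated version.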