
[View code on Github](https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/gan/wasserstein/__init__.py)

# Wasserstein GAN (WGAN)

This is an implementation of Wasserstein GAN.

The original GAN loss is based on the Jensen-Shannon (JS) divergence between the real distribution $\mathbb{P}_r$ and the generated distribution $\mathbb{P}_g$. The Wasserstein GAN is based on the Earth Mover distance between these distributions.

$$W(\mathbb{P}_r, \mathbb{P}_g) = \inf_{\gamma \in \Pi(\mathbb{P}_r, \mathbb{P}_g)} \mathbb{E}_{(x, y) \sim \gamma} \Vert x - y \Vert$$

$\Pi(\mathbb{P}_r, \mathbb{P}_g)$ is the set of all joint distributions $\gamma(x, y)$ whose marginals are $\mathbb{P}_r$ and $\mathbb{P}_g$.

$\mathbb{E}_{(x, y) \sim \gamma} \Vert x - y \Vert$ is the earth mover cost for a given joint distribution $\gamma$ ($x$ and $y$ are samples drawn from the coupling $\gamma$).

So $W(\mathbb{P}_r, \mathbb{P}_g)$ is equal to the least earth mover cost over all joint distributions of the real distribution $\mathbb{P}_r$ and the generated distribution $\mathbb{P}_g$.
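
As a concrete illustration (not part of the original implementation): for two one-dimensional empirical distributions with the same number of equally weighted samples, the optimal coupling simply matches sorted samples, so the earth mover distance is the mean absolute difference of the sorted samples. A minimal sketch with numpy, where the sample sizes and distributions are arbitrary choices:

```python
import numpy as np

# Hypothetical 1-D samples standing in for the real and generated distributions.
real = np.random.normal(loc=0.0, scale=1.0, size=1000)
fake = np.random.normal(loc=2.0, scale=1.0, size=1000)

# For equally weighted 1-D empirical distributions of the same size, the optimal
# coupling pairs the i-th smallest real sample with the i-th smallest fake sample,
# so the earth mover distance is the mean absolute difference of sorted samples.
w1 = np.abs(np.sort(real) - np.sort(fake)).mean()
print(w1)  # close to 2.0, the distance between the two means
```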

The paper shows that the Jensen-Shannon (JS) divergence and other measures of the difference between two probability distributions are not smooth. Therefore, if we do gradient descent on one of the (parameterized) probability distributions, it will not converge. For example, if $\mathbb{P}_r$ and $\mathbb{P}_\theta$ are point masses at $0$ and $\theta$, the JS divergence is the constant $\log 2$ for every $\theta \ne 0$ and gives no gradient, while $W(\mathbb{P}_r, \mathbb{P}_\theta) = |\theta|$ shrinks smoothly as $\theta \to 0$.
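
A small numerical check of that example (illustrative only and not part of the implementation; the point masses live on a hypothetical integer grid):

```python
import numpy as np

def js_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """Jensen-Shannon divergence between two discrete distributions."""
    m = 0.5 * (p + q)
    def kl(a, b):
        mask = a > 0
        return float(np.sum(a[mask] * np.log(a[mask] / b[mask])))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

xs = np.arange(11)
p_r = (xs == 0).astype(float)              # point mass at 0
for theta in [1, 2, 5, 10]:
    p_theta = (xs == theta).astype(float)  # point mass at theta
    # JS stays at log 2 ~= 0.693 however far apart the masses are,
    # while the earth mover distance is simply |theta|.
    print(theta, round(js_divergence(p_r, p_theta), 3), float(theta))
```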

Based on Kantorovich-Rubinstein duality,

$$W(\mathbb{P}_r, \mathbb{P}_g) = \sup_{\Vert f \Vert_L \le 1} \mathbb{E}_{x \sim \mathbb{P}_r} [f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g} [f(x)]$$

where the supremum is taken over all 1-Lipschitz functions $f$ (i.e., $\Vert f \Vert_L \le 1$).

That is, it is equal to the greatest difference $\mathbb{E}_{x \sim \mathbb{P}_r} [f(x)] - \mathbb{E}_{x \sim \mathbb{P}_g} [f(x)]$ among all 1-Lipschitz functions.

For $K$-Lipschitz functions (since $\frac{1}{K} f$ is 1-Lipschitz whenever $f$ is $K$-Lipschitz),

$$W(\mathbb{P}_r, \mathbb{P}_g) = \sup_{\Vert f \Vert_L \le K} \mathbb{E}_{x \sim \mathbb{P}_r} \Big[ \frac{1}{K} f(x) \Big] - \mathbb{E}_{x \sim \mathbb{P}_g} \Big[ \frac{1}{K} f(x) \Big]$$

If all $K$-Lipschitz functions can be represented as $f_w$, where $f$ is parameterized by $w \in \mathcal{W}$,

$$K \cdot W(\mathbb{P}_r, \mathbb{P}_g) = \max_{w \in \mathcal{W}} \mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)] - \mathbb{E}_{x \sim \mathbb{P}_g} [f_w(x)]$$

If $\mathbb{P}_g$ is represented by a generator $g_\theta(z)$, where $z$ comes from a known distribution $z \sim p(z)$,

$$K \cdot W(\mathbb{P}_r, \mathbb{P}_\theta) = \max_{w \in \mathcal{W}} \mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)] - \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))]$$

Now, to bring $g_\theta$ closer to $\mathbb{P}_r$, we can do gradient descent on $\theta$ to minimize the formula above.

Similarly, we can find $\max_{w \in \mathcal{W}}$ by gradient ascent on $w$, while keeping $K$ bounded. One way to keep $K$ bounded is to clip all the weights of the neural network that defines $f$ to a fixed range.
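
A minimal sketch of that clipping step (the critic network, the RMSProp optimizer, and the learning rate below are illustrative choices and not part of this file; $c = 0.01$ is the clip value from the WGAN paper):

```python
import torch
from torch import nn

# Hypothetical critic f_w; the actual experiment uses a convolutional network.
critic = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(), nn.Linear(256, 1))
optimizer = torch.optim.RMSprop(critic.parameters(), lr=5e-5)

c = 0.01  # clipping range

# ... compute the critic loss on a batch and call loss.backward() here ...
optimizer.step()

# Keep K bounded by clipping every weight of f_w to [-c, c] after each update.
with torch.no_grad():
    for p in critic.parameters():
        p.clamp_(-c, c)
```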

Here is the code to try this on a simple MNIST generation experiment.

```python
import torch.utils.data
from torch import nn
from torch.nn import functional as F
```

## Discriminator Loss

We want to find $w$ to maximize $\mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)] - \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))]$, so we minimize

$$-\frac{1}{m} \sum_{i=1}^{m} f_w \big( x^{(i)} \big) + \frac{1}{m} \sum_{i=1}^{m} f_w \big( g_\theta(z^{(i)}) \big)$$

```python
class DiscriminatorLoss(nn.Module):
```


  • `f_real` is $f_w(x)$
  • `f_fake` is $f_w(g_\theta(z))$

This returns a tuple with the losses for $f_w(x)$ and $f_w(g_\theta(z))$, which are added later. They are kept separate for logging.

```python
def forward(self, f_real: torch.Tensor, f_fake: torch.Tensor):
```


We use ReLUs to clip the loss to keep $f$ in the $[-1, +1]$ range.

```python
return F.relu(1 - f_real).mean(), F.relu(1 + f_fake).mean()
```

## Generator Loss

We want to find $\theta$ to minimize $\mathbb{E}_{x \sim \mathbb{P}_r} [f_w(x)] - \mathbb{E}_{z \sim p(z)} [f_w(g_\theta(z))]$. The first component is independent of $\theta$, so we minimize

$$-\frac{1}{m} \sum_{i=1}^{m} f_w \big( g_\theta(z^{(i)}) \big)$$

```python
class GeneratorLoss(nn.Module):
```


  • `f_fake` is $f_w(g_\theta(z))$

```python
def forward(self, f_fake: torch.Tensor):
```


```python
return -f_fake.mean()
```
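
Below is a hypothetical sketch of how these two losses could be wired into a single training step, together with the weight clipping discussed above. The `discriminator`, `generator`, the optimizers, the latent size, and the clip value are assumptions for illustration and are not part of this file:

```python
def train_step(discriminator: nn.Module, generator: nn.Module,
               discriminator_optimizer: torch.optim.Optimizer,
               generator_optimizer: torch.optim.Optimizer,
               real_batch: torch.Tensor, latent_dim: int = 100, clip: float = 0.01):
    # Hypothetical wiring of DiscriminatorLoss and GeneratorLoss from this file.
    discriminator_loss = DiscriminatorLoss()
    generator_loss = GeneratorLoss()

    # Critic (discriminator) update: ascend on w by minimizing DiscriminatorLoss.
    z = torch.randn(real_batch.shape[0], latent_dim)
    loss_real, loss_fake = discriminator_loss(discriminator(real_batch),
                                              discriminator(generator(z).detach()))
    d_loss = loss_real + loss_fake
    discriminator_optimizer.zero_grad()
    d_loss.backward()
    discriminator_optimizer.step()

    # Clip the critic weights to keep f_w (approximately) K-Lipschitz.
    with torch.no_grad():
        for p in discriminator.parameters():
            p.clamp_(-clip, clip)

    # Generator update: descend on theta by minimizing -f_w(g_theta(z)).
    z = torch.randn(real_batch.shape[0], latent_dim)
    g_loss = generator_loss(discriminator(generator(z)))
    generator_optimizer.zero_grad()
    g_loss.backward()
    generator_optimizer.step()
```

In the paper, the critic is updated several times (e.g. $n_{\text{critic}} = 5$) for each generator update; that inner loop is omitted here for brevity.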
