WGAN-GP experiment with MNIST

```python
import torch

from labml import experiment, tracker
```

Import configurations from Wasserstein experiment

```python
from labml_nn.gan.wasserstein.experiment import Configs as OriginalConfigs
```

```python
from labml_nn.gan.wasserstein.gradient_penalty import GradientPenalty
```
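The `GradientPenalty` module lives in `labml_nn.gan.wasserstein.gradient_penalty` and is not shown on this page. Below is a rough sketch of a module with the same call shape, `(x, f)`. It assumes the module computes the usual WGAN-GP term, the batch mean of (‖∇ₓf‖₂ - 1)². The class name `GradientPenaltySketch` is hypothetical, and this is not the labml_nn source.

```python
import torch
import torch.nn as nn


class GradientPenaltySketch(nn.Module):
    """Hypothetical stand-in for GradientPenalty: mean of (||grad_x f||_2 - 1)^2."""

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        # Gradient of the critic outputs f with respect to the inputs x;
        # create_graph=True lets the penalty itself be backpropagated into the critic
        gradients, = torch.autograd.grad(outputs=f, inputs=x,
                                         grad_outputs=torch.ones_like(f),
                                         create_graph=True)
        # Flatten all non-batch dimensions and take one L2 norm per sample
        norm = gradients.reshape(batch_size, -1).norm(2, dim=-1)
        return torch.mean((norm - 1) ** 2)


# Linear critic f(x) = x @ w, so the gradient w.r.t. every sample is w, with ||w||_2 = 5
w = torch.tensor([[3.0], [4.0]], requires_grad=True)
x = torch.randn(4, 2, requires_grad=True)
penalty = GradientPenaltySketch()(x, x @ w)
penalty.backward()
print(penalty.item())  # 16.0, i.e. (5 - 1)^2
print(w.grad)          # expected 2 * (5 - 1) * w / 5, i.e. about [[4.8], [6.4]]
```

The `create_graph=True` flag is what lets the penalty reach the critic's weights: without it, `w.grad` would not be populated by `penalty.backward()`.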

Configuration class

We extend the original Wasserstein GAN configurations and override the discriminator (critic) loss calculation to include the gradient penalty.
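In equation form, `calc_discriminator_loss` assembles the training-time critic loss from three parts: the two terms returned by `discriminator_loss`, plus λ times the penalty. The labels $\mathcal{L}_{\text{true}}$ and $\mathcal{L}_{\text{false}}$ stand for `loss_true` and `loss_false`. The penalty is written at the real samples x, because that is where this code calls it. Its squared-norm form assumes `GradientPenalty` computes the standard WGAN-GP term.

$$
\mathcal{L}_D = \mathcal{L}_{\text{true}}\big(D(x)\big) + \mathcal{L}_{\text{false}}\big(D(G_\theta(z))\big) + \lambda\, \mathbb{E}_{x}\Big[\big(\lVert \nabla_x D(x) \rVert_2 - 1\big)^2\Big], \qquad z \sim p(z),\quad \lambda = 10
$$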

```python
class Configs(OriginalConfigs):
```

Gradient penalty coefficient λ

```python
    gradient_penalty_coefficient: float = 10.0
```

```python
    gradient_penalty = GradientPenalty()
```

This overrides the original discriminator loss calculation to include the gradient penalty.

```python
    def calc_discriminator_loss(self, data: torch.Tensor):
```

Require gradients on x to calculate gradient penalty

```python
        data.requires_grad_()
```
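`torch.autograd.grad` can only differentiate with respect to tensors that are part of the autograd graph. That is why the input batch needs `requires_grad_()` before the critic is applied. The standalone check below demonstrates this. It uses a toy `nn.Linear` critic as a stand-in for the real discriminator (an assumption for illustration).

```python
import torch
import torch.nn as nn

critic = nn.Linear(3, 1)  # toy critic, stands in for the real discriminator
data = torch.randn(5, 3)  # plain batch: requires_grad is False

try:
    torch.autograd.grad(critic(data).sum(), data)
    failed = False
except RuntimeError:
    # data is not part of the graph, so there is no gradient with respect to it
    failed = True

data.requires_grad_()  # in-place: data now tracks gradients
grads, = torch.autograd.grad(critic(data).sum(), data)

print(failed, grads.shape)  # True torch.Size([5, 3])
```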

Sample z∼p(z)

```python
        latent = self.sample_z(data.shape[0])
```
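`sample_z` is inherited from the original configurations and is not shown here. A common choice is a standard normal latent, one vector per sample in the batch. This is an assumption, not taken from labml_nn. The helper name `sample_z_sketch` and the latent size of 100 are hypothetical.

```python
import torch


def sample_z_sketch(batch_size: int, latent_dim: int = 100) -> torch.Tensor:
    """Hypothetical: draw z ~ N(0, I) with one latent vector per sample."""
    return torch.randn(batch_size, latent_dim)


latent = sample_z_sketch(64)
print(latent.shape)  # torch.Size([64, 100])
```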

D(x)

```python
        f_real = self.discriminator(data)
```

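D(G_θ(z)), computed on the detached generator output so that this loss does not backpropagate into the generator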

```python
        f_fake = self.discriminator(self.generator(latent).detach())
```
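Calling `.detach()` on the generator output means the critic loss updates only the critic: the generator's parameters receive no gradient from this step. The toy check below shows this, with two `nn.Linear` layers of hypothetical sizes standing in for the generator and the critic.

```python
import torch
import torch.nn as nn

generator = nn.Linear(4, 3)  # toy generator: latent (4) -> sample (3)
critic = nn.Linear(3, 1)     # toy critic: sample (3) -> score (1)

latent = torch.randn(8, 4)
f_fake = critic(generator(latent).detach())  # the graph stops at the generator output
f_fake.mean().backward()

print(generator.weight.grad is None, critic.weight.grad is not None)  # True True
```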

Get discriminator losses

```python
        loss_true, loss_false = self.discriminator_loss(f_real, f_fake)
```
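`discriminator_loss` comes from the base configurations and is selected by `'discriminator_loss': 'wasserstein'` in `main`. Its exact definition is not on this page and may differ from what follows. This is a minimal sketch of the textbook Wasserstein critic loss, which the critic minimizes to push D(x) up and D(G(z)) down. Like `loss_true, loss_false`, it returns the two terms separately. The function name is hypothetical.

```python
import torch


def wasserstein_critic_loss_sketch(f_real: torch.Tensor, f_fake: torch.Tensor):
    """Hypothetical textbook form: minimize -E[D(x)] + E[D(G(z))]."""
    loss_true = -f_real.mean()  # lower when the critic scores real samples high
    loss_false = f_fake.mean()  # lower when the critic scores fake samples low
    return loss_true, loss_false


f_real = torch.tensor([2.0, 4.0])
f_fake = torch.tensor([-1.0, 1.0, 3.0])
loss_true, loss_false = wasserstein_critic_loss_sketch(f_real, f_fake)
print(loss_true.item(), loss_false.item())  # -3.0 1.0
```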

Calculate the gradient penalty in training mode and add it to the loss, scaled by the coefficient λ

```python
        if self.mode.is_train:
            gradient_penalty = self.gradient_penalty(data, f_real)
            tracker.add("loss.gp.", gradient_penalty)
            loss = loss_true + loss_false + self.gradient_penalty_coefficient * gradient_penalty
```
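Here the penalty is evaluated at the real batch (`data`, `f_real`). The WGAN-GP paper (Improved Training of Wasserstein GANs) takes a different approach. It samples points on straight lines between real and generated samples, x̂ = εx + (1 - ε)x̃ with ε ∼ U[0, 1], and penalizes the critic's gradient norm at those points. The sketch below shows that interpolated variant for comparison, scaled by λ = 10 as in `gradient_penalty_coefficient`. It is a swapped-in technique, not what this experiment runs, and `interpolated_gradient_penalty` is a hypothetical helper.

```python
import torch
import torch.nn as nn


def interpolated_gradient_penalty(critic: nn.Module, real: torch.Tensor,
                                  fake: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: penalty at random interpolates between real and fake samples."""
    batch_size = real.shape[0]
    # One epsilon per sample, broadcast over the remaining dimensions
    eps = torch.rand(batch_size, *([1] * (real.dim() - 1)), device=real.device)
    x_hat = (eps * real + (1 - eps) * fake.detach()).requires_grad_()
    f_hat = critic(x_hat)
    gradients, = torch.autograd.grad(f_hat, x_hat, grad_outputs=torch.ones_like(f_hat),
                                     create_graph=True)
    norm = gradients.reshape(batch_size, -1).norm(2, dim=-1)
    return torch.mean((norm - 1) ** 2)


critic = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 1))  # toy MNIST-sized critic
real = torch.randn(16, 1, 28, 28)
fake = torch.randn(16, 1, 28, 28)
gp = interpolated_gradient_penalty(critic, real, fake)
loss_gp_term = 10.0 * gp  # lambda = 10, matching gradient_penalty_coefficient
loss_gp_term.backward()   # gradients flow into the critic's weights
```

With a linear toy critic, the gradient at every interpolate equals the weight vector, so `gp` reduces to (‖W‖₂ - 1)². A real convolutional critic has input-dependent gradients, which is where the choice of evaluation points matters.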

Skip gradient penalty otherwise

```python
        else:
            loss = loss_true + loss_false
```

Log the individual loss terms and the total discriminator loss

```python
        tracker.add("loss.discriminator.true.", loss_true)
        tracker.add("loss.discriminator.false.", loss_false)
        tracker.add("loss.discriminator.", loss)

        return loss
```

```python
def main():
```

Create configs object

```python
    conf = Configs()
```

Create experiment

```python
    experiment.create(name='mnist_wassertein_gp_dcgan')
```

Override configurations

```python
    experiment.configs(conf,
                       {
                           'discriminator': 'cnn',
                           'generator': 'cnn',
                           'label_smoothing': 0.01,
                           'generator_loss': 'wasserstein',
                           'discriminator_loss': 'wasserstein',
                           'discriminator_k': 5,
                       })
```
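`'discriminator_k': 5` presumably sets how many critic updates are made per generator update, the usual WGAN schedule. The training loop itself is in the base configurations and is not shown here. Under that assumption, the toy version below shows the 5:1 update pattern. It uses hypothetical one-layer models and the textbook Wasserstein losses, and omits the gradient penalty for brevity.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
generator = nn.Linear(4, 2)  # toy generator: latent (4) -> sample (2)
critic = nn.Linear(2, 1)     # toy critic: sample (2) -> score (1)
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-3)
opt_c = torch.optim.Adam(critic.parameters(), lr=1e-3)

discriminator_k = 5  # assumed meaning: critic updates per generator update
critic_steps, generator_steps = 0, 0

for step in range(1, 21):
    real = torch.randn(32, 2) + 3.0  # toy "real" data
    # Critic update on every step, with the generator output detached
    fake = generator(torch.randn(32, 4)).detach()
    loss_c = -critic(real).mean() + critic(fake).mean()
    opt_c.zero_grad()
    loss_c.backward()
    opt_c.step()
    critic_steps += 1
    # Generator update only every discriminator_k steps
    if step % discriminator_k == 0:
        loss_g = -critic(generator(torch.randn(32, 4))).mean()
        opt_g.zero_grad()
        loss_g.backward()
        opt_g.step()
        generator_steps += 1

print(critic_steps, generator_steps)  # 20 4
```

The generator step also leaves gradients on the critic's parameters. The `opt_c.zero_grad()` call at the start of the next critic update clears them.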

Start the experiment and run training loop

```python
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()
```

labml.ai