docs/gan/wasserstein/gradient_penalty/experiment.html
import torch

from labml import experiment, tracker
Import configurations from the Wasserstein GAN experiment
from labml_nn.gan.wasserstein.experiment import Configs as OriginalConfigs
from labml_nn.gan.wasserstein.gradient_penalty import GradientPenalty
We extend the original Wasserstein GAN implementation and override the discriminator (critic) loss calculation to include the gradient penalty.
class Configs(OriginalConfigs):
Gradient penalty coefficient λ
    gradient_penalty_coefficient: float = 10.0
    gradient_penalty = GradientPenalty()
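For reference, the critic loss assembled in this class follows the WGAN-GP objective of Gulrajani et al. (2017):

$$\mathcal{L} = \mathbb{E}_{\tilde{x} \sim p_g}\big[D(\tilde{x})\big] - \mathbb{E}_{x \sim p_r}\big[D(x)\big] + \lambda \, \mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]$$

Note that in this experiment the penalty term is evaluated at the real samples $x$ (the loss method passes `data` and `f_real` to `gradient_penalty`), rather than at random interpolates $\hat{x}$ between real and generated samples as in the original paper.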
This overrides the original discriminator loss calculation and adds the gradient penalty.
    def calc_discriminator_loss(self, data: torch.Tensor):
Require gradients on x to calculate the gradient penalty
        data.requires_grad_()
Sample z∼p(z)
        latent = self.sample_z(data.shape[0])
D(x)
        f_real = self.discriminator(data)
D(Gθ(z))
        f_fake = self.discriminator(self.generator(latent).detach())
Get discriminator losses
        loss_true, loss_false = self.discriminator_loss(f_real, f_fake)
Calculate the gradient penalty in training mode
        if self.mode.is_train:
            gradient_penalty = self.gradient_penalty(data, f_real)
            tracker.add("loss.gp.", gradient_penalty)
            loss = loss_true + loss_false + self.gradient_penalty_coefficient * gradient_penalty
Skip gradient penalty otherwise
        else:
            loss = loss_true + loss_false
Log the losses
        tracker.add("loss.discriminator.true.", loss_true)
        tracker.add("loss.discriminator.false.", loss_false)
        tracker.add("loss.discriminator.", loss)

        return loss
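`GradientPenalty` is imported from `labml_nn` above. As an illustrative sketch only (not the library's exact code), a one-centered gradient penalty of this kind can be computed with `torch.autograd.grad`:

```python
import torch


def gradient_penalty(x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
    """Illustrative one-centered gradient penalty: (||∇_x D(x)||_2 - 1)^2,
    averaged over the batch. `x` must have requires_grad=True and `f`
    must be the critic output computed from `x`."""
    batch_size = x.shape[0]
    # Gradients of the critic output with respect to the inputs.
    # create_graph=True keeps the graph so the penalty itself is differentiable
    # and can be backpropagated through during the critic update.
    gradients, = torch.autograd.grad(outputs=f, inputs=x,
                                     grad_outputs=torch.ones_like(f),
                                     create_graph=True)
    # Flatten per-sample gradients and take their L2 norms
    norms = gradients.reshape(batch_size, -1).norm(2, dim=-1)
    # Penalize deviation of each norm from 1
    return ((norms - 1) ** 2).mean()
```

This is why `data.requires_grad_()` is called before the forward pass: without it, `torch.autograd.grad` cannot compute gradients with respect to the input samples.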
def main():
Create configs object
    conf = Configs()
Create experiment
    experiment.create(name='mnist_wassertein_gp_dcgan')
Override configurations
    experiment.configs(conf,
                       {
                           'discriminator': 'cnn',
                           'generator': 'cnn',
                           'label_smoothing': 0.01,
                           'generator_loss': 'wasserstein',
                           'discriminator_loss': 'wasserstein',
                           'discriminator_k': 5,
                       })
Start the experiment and run training loop
81withexperiment.start():82conf.run()838485if\_\_name\_\_=='\_\_main\_\_':86main()