Back to Annotated Deep Learning Paper Implementations

MNISTසමඟ WGAN-GP අත්හදා බැලීම

docs/si/gan/wasserstein/gradient_penalty/experiment.html

latest4.0 KB
Original Source

homeganwassersteingradient_penalty

View code on Github

#

MNISTසමඟ WGAN-GP අත්හදා බැලීම

10importtorch1112fromlabmlimportexperiment,tracker

#

වොසර්ස්ටයින් අත්හදා බැලීමෙන් වින්යාසයන් ආනයනය කරන්න

14fromlabml\_nn.gan.wasserstein.experimentimportConfigsasOriginalConfigs

#

16fromlabml\_nn.gan.wasserstein.gradient\_penaltyimportGradientPenalty

#

වින්යාසපන්තිය

අපි මුල් GAN ක්රියාත්මක කිරීම දීර් extend කර වර්ගීකරණ ද penalty ුවම ඇතුළත් කිරීම සඳහා වෙනස්කම් කරන්නා (විචාරක) පාඩු ගණනය කිරීම අභිබවා යමු.

19classConfigs(OriginalConfigs):

#

ශ්රේණියේදණ්ඩන සංගුණකය λ

28gradient\_penalty\_coefficient:float=10.0

#

30gradient\_penalty=GradientPenalty()

#

මෙයමුල් වෙනස්කම් කරන්නාගේ අලාභය ගණනය කිරීම අභිබවා යන අතර ශ්රේණියේ ද penalty ුවම් ද ඇතුළත් වේ.

32defcalc\_discriminator\_loss(self,data:torch.Tensor):

#

ඵලයඅනුක්රමික දඬුවම ගණනය x කිරීමට මත ඵලය අනුක්රමික අවශ්ය

38data.requires\_grad\_()

#

නියැදිය z∼p(z)

40latent=self.sample\_z(data.shape[0])

#

D(x)

42f\_real=self.discriminator(data)

#

D(Gθ​(z))

44f\_fake=self.discriminator(self.generator(latent).detach())

#

වෙනස්කම්කරන්නාගේ පාඩු ලබා ගන්න

46loss\_true,loss\_false=self.discriminator\_loss(f\_real,f\_fake)

#

පුහුණුප්රකාරයේදී ශ්රේණියේ ද ties ුවම් ගණනය කරන්න

48ifself.mode.is\_train:49gradient\_penalty=self.gradient\_penalty(data,f\_real)50tracker.add("loss.gp.",gradient\_penalty)51loss=loss\_true+loss\_false+self.gradient\_penalty\_coefficient\*gradient\_penalty

#

වෙනත්ආකාරයකින් ශ්රේණියේ ද penalty ුවම මඟ හරින්න

53else:54loss=loss\_true+loss\_false

#

ලොග්දේවල්

57tracker.add("loss.discriminator.true.",loss\_true)58tracker.add("loss.discriminator.false.",loss\_false)59tracker.add("loss.discriminator.",loss)6061returnloss

#

64defmain():

#

වින්යාසවස්තුව සාදන්න

66conf=Configs()

#

අත්හදාබැලීම සාදන්න

68experiment.create(name='mnist\_wassertein\_gp\_dcgan')

#

වින්යාසයන්අභිබවා යන්න

70experiment.configs(conf,71{72'discriminator':'cnn',73'generator':'cnn',74'label\_smoothing':0.01,75'generator\_loss':'wasserstein',76'discriminator\_loss':'wasserstein',77'discriminator\_k':5,78})

#

අත්හදාබැලීම ආරම්භ කර පුහුණු ලූපය ක්රියාත්මක කරන්න

81withexperiment.start():82conf.run()838485if\_\_name\_\_=='\_\_main\_\_':86main()

Trending Research Paperslabml.ai