Back to Annotated Deep Learning Paper Implementations

උත්පාදකඅහිතකර ජාල MNIST සමඟ අත්හදා බැලීම

docs/si/gan/original/experiment.html

latest9.3 KB
Original Source

homeganoriginal

View code on Github

#

උත්පාදකඅහිතකර ජාල MNIST සමඟ අත්හදා බැලීම

10fromtypingimportAny1112importtorch13importtorch.nnasnn14importtorch.utils.data15fromtorchvisionimporttransforms1617fromlabmlimporttracker,monit,experiment18fromlabml.configsimportoption,calculate19fromlabml\_helpers.datasets.mnistimportMNISTConfigs20fromlabml\_helpers.deviceimportDeviceConfigs21fromlabml\_helpers.moduleimportModule22fromlabml\_helpers.optimizerimportOptimizerConfigs23fromlabml\_helpers.train\_validimportTrainValidConfigs,hook\_model\_outputs,BatchIndex24fromlabml\_nn.gan.originalimportDiscriminatorLogitsLoss,GeneratorLogitsLoss

#

27defweights\_init(m):28classname=m.\_\_class\_\_.\_\_name\_\_29ifclassname.find('Linear')!=-1:30nn.init.normal\_(m.weight.data,0.0,0.02)31elifclassname.find('BatchNorm')!=-1:32nn.init.normal\_(m.weight.data,1.0,0.02)33nn.init.constant\_(m.bias.data,0)

#

සරලඑම්එල්පී උත්පාදක යන්ත්රය

මෙය LeakyReLU සක්රිය කිරීම් සමඟ ප්රමාණය වැඩි කිරීමේ රේඛීය ස්ථර තුනක් ඇත. අවසාන ස්ථරය tanh සක්රිය කිරීමක් ඇත.

36classGenerator(Module):

#

44def\_\_init\_\_(self):45super().\_\_init\_\_()46layer\_sizes=[256,512,1024]47layers=[]48d\_prev=10049forsizeinlayer\_sizes:50layers=layers+[nn.Linear(d\_prev,size),nn.LeakyReLU(0.2)]51d\_prev=size5253self.layers=nn.Sequential(\*layers,nn.Linear(d\_prev,28\*28),nn.Tanh())5455self.apply(weights\_init)

#

57defforward(self,x):58returnself.layers(x).view(x.shape[0],1,28,28)

#

සරලඑම්එල්පී වෙනස්කම් කරන්නා

මෙය LeakyReLU සක්රිය කිරීම් සමඟ ප්රමාණය අඩු කිරීමේ රේඛීය ස්ථර තුනක් ඇත. අවසාන ස්ථරයට තනි ප්රතිදානයක් ඇති අතර එමඟින් ආදානය සැබෑ හෝ ව්යාජ ද යන්න පිළිබඳ පිවිසුම ලබා දේ. එය සිග්මෝයිඩ් ගණනය කිරීමෙන් ඔබට සම්භාවිතාව ලබා ගත හැකිය.

61classDiscriminator(Module):

#

70def\_\_init\_\_(self):71super().\_\_init\_\_()72layer\_sizes=[1024,512,256]73layers=[]74d\_prev=28\*2875forsizeinlayer\_sizes:76layers=layers+[nn.Linear(d\_prev,size),nn.LeakyReLU(0.2)]77d\_prev=size7879self.layers=nn.Sequential(\*layers,nn.Linear(d\_prev,1))80self.apply(weights\_init)

#

82defforward(self,x):83returnself.layers(x.view(x.shape[0],-1))

#

වින්යාසකිරීම්

අපගේක්රියාත්මක කිරීම සරල කිරීම සඳහා දත්ත පැටවුම් සහ පුහුණු සහ වලංගු කිරීමේ ලූප වින්යාසයන් ලබා ගැනීම සඳහා මෙය MNIST වින්යාසයන් පුළුල් කරයි.

86classConfigs(MNISTConfigs,TrainValidConfigs):

#

94device:torch.device=DeviceConfigs()95dataset\_transforms='mnist\_gan\_transforms'96epochs:int=109798is\_save\_models=True99discriminator:Module='mlp'100generator:Module='mlp'101generator\_optimizer:torch.optim.Adam102discriminator\_optimizer:torch.optim.Adam103generator\_loss:GeneratorLogitsLoss='original'104discriminator\_loss:DiscriminatorLogitsLoss='original'105label\_smoothing:float=0.2106discriminator\_k:int=1

#

ආරම්භකකරණය

108definit(self):

#

112self.state\_modules=[]113114hook\_model\_outputs(self.mode,self.generator,'generator')115hook\_model\_outputs(self.mode,self.discriminator,'discriminator')116tracker.set\_scalar("loss.generator.\*",True)117tracker.set\_scalar("loss.discriminator.\*",True)118tracker.set\_image("generated",True,1/100)

#

z∼p(z)

120defsample\_z(self,batch\_size:int):

#

124returntorch.randn(batch\_size,100,device=self.device)

#

පුහුණුපියවරක් ගන්න

126defstep(self,batch:Any,batch\_idx:BatchIndex):

#

ආදර්ශතත්වයන් සකසන්න

132self.generator.train(self.mode.is\_train)133self.discriminator.train(self.mode.is\_train)

#

MNISTරූප ලබා ගන්න

136data=batch[0].to(self.device)

#

පුහුණුමාදිලියේ වර්ධක පියවර

139ifself.mode.is\_train:140tracker.add\_global\_step(len(data))

#

වෙනස්කම්කරන්නා පුහුණු කරන්න

143withmonit.section("discriminator"):

#

වෙනස්කම්කරන්නාගේ පාඩුව ලබා ගන්න

145loss=self.calc\_discriminator\_loss(data)

#

දුම්රිය

148ifself.mode.is\_train:149self.discriminator\_optimizer.zero\_grad()150loss.backward()151ifbatch\_idx.is\_last:152tracker.add('discriminator',self.discriminator)153self.discriminator\_optimizer.step()

#

සෑමවිටම උත්පාදක යන්ත්රය පුහුණු කරන්න discriminator_k

156ifbatch\_idx.is\_interval(self.discriminator\_k):157withmonit.section("generator"):158loss=self.calc\_generator\_loss(data.shape[0])

#

දුම්රිය

161ifself.mode.is\_train:162self.generator\_optimizer.zero\_grad()163loss.backward()164ifbatch\_idx.is\_last:165tracker.add('generator',self.generator)166self.generator\_optimizer.step()167168tracker.save()

#

වෙනස්කම්කරන්නාගේ පාඩුව ගණනය කරන්න

170defcalc\_discriminator\_loss(self,data):

#

174latent=self.sample\_z(data.shape[0])175logits\_true=self.discriminator(data)176logits\_false=self.discriminator(self.generator(latent).detach())177loss\_true,loss\_false=self.discriminator\_loss(logits\_true,logits\_false)178loss=loss\_true+loss\_false

#

ලොග්දේවල්

181tracker.add("loss.discriminator.true.",loss\_true)182tracker.add("loss.discriminator.false.",loss\_false)183tracker.add("loss.discriminator.",loss)184185returnloss

#

උත්පාදකඅලාභය ගණනය කරන්න

187defcalc\_generator\_loss(self,batch\_size:int):

#

191latent=self.sample\_z(batch\_size)192generated\_images=self.generator(latent)193logits=self.discriminator(generated\_images)194loss=self.generator\_loss(logits)

#

ලොග්දේවල්

197tracker.add('generated',generated\_images[0:6])198tracker.add("loss.generator.",loss)199200returnloss

#

205@option(Configs.dataset\_transforms)206defmnist\_gan\_transforms():207returntransforms.Compose([208transforms.ToTensor(),209transforms.Normalize((0.5,),(0.5,))210])211212213@option(Configs.discriminator\_optimizer)214def\_discriminator\_optimizer(c:Configs):215opt\_conf=OptimizerConfigs()216opt\_conf.optimizer='Adam'217opt\_conf.parameters=c.discriminator.parameters()218opt\_conf.learning\_rate=2.5e-4

#

ශ්රේණියේපළමු මොහොත සඳහා on ාතීය ක්ෂය වීමේ අනුපාතය සැකසීම වැදගත් 0.5 වේ. β1​ 0.9 අසමත් වීමේ පෙරනිමි.

222opt\_conf.betas=(0.5,0.999)223returnopt\_conf

#

226@option(Configs.generator\_optimizer)227def\_generator\_optimizer(c:Configs):228opt\_conf=OptimizerConfigs()229opt\_conf.optimizer='Adam'230opt\_conf.parameters=c.generator.parameters()231opt\_conf.learning\_rate=2.5e-4

#

ශ්රේණියේපළමු මොහොත සඳහා on ාතීය ක්ෂය වීමේ අනුපාතය සැකසීම වැදගත් 0.5 වේ. β1​ 0.9 අසමත් වීමේ පෙරනිමි.

235opt\_conf.betas=(0.5,0.999)236returnopt\_conf237238239calculate(Configs.generator,'mlp',lambdac:Generator().to(c.device))240calculate(Configs.discriminator,'mlp',lambdac:Discriminator().to(c.device))241calculate(Configs.generator\_loss,'original',lambdac:GeneratorLogitsLoss(c.label\_smoothing).to(c.device))242calculate(Configs.discriminator\_loss,'original',lambdac:DiscriminatorLogitsLoss(c.label\_smoothing).to(c.device))

#

245defmain():246conf=Configs()247experiment.create(name='mnist\_gan',comment='test')248experiment.configs(conf,249{'label\_smoothing':0.01})250withexperiment.start():251conf.run()252253254if\_\_name\_\_=='\_\_main\_\_':255main()

Trending Research Paperslabml.ai