from typing import Any

from torchvision import transforms

import torch
import torch.nn as nn
import torch.utils.data
from labml import tracker, monit, experiment
from labml.configs import option, calculate
from labml_nn.gan.original import DiscriminatorLogitsLoss, GeneratorLogitsLoss
from labml_nn.helpers.datasets import MNISTConfigs
from labml_nn.helpers.device import DeviceConfigs
from labml_nn.helpers.optimizer import OptimizerConfigs
from labml_nn.helpers.trainer import TrainValidConfigs, BatchIndex
Initialize Linear layer weights from N(0, 0.02), and BatchNorm weights from N(1, 0.02) with zero bias.

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
The generator is a simple MLP with three linear layers of increasing size, each followed by a LeakyReLU activation. The final layer maps to 28×28 outputs with a tanh activation.
class Generator(nn.Module):
def __init__(self):
    super().__init__()
    layer_sizes = [256, 512, 1024]
    layers = []
    d_prev = 100
    for size in layer_sizes:
        layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
        d_prev = size

    self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 28 * 28), nn.Tanh())

    self.apply(weights_init)
def forward(self, x):
    return self.layers(x).view(x.shape[0], 1, 28, 28)
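As a quick sanity check (a hypothetical snippet, not part of the experiment), the generator maps a batch of 100-dimensional latent vectors to 1×28×28 images in [-1, 1]:

# Hypothetical shape check; `Generator` is defined above.
generator = Generator()
z = torch.randn(16, 100)                   # a batch of 16 latent vectors
images = generator(z)                      # shape: [16, 1, 28, 28]
assert images.shape == (16, 1, 28, 28)
assert images.min() >= -1.0 and images.max() <= 1.0   # tanh output range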
The discriminator is an MLP with three linear layers of decreasing size, each followed by a LeakyReLU activation. The final layer has a single output that gives the logit of whether the input is real or fake. You can get the probability by taking the sigmoid of the logit, as shown in the sketch after this class.
class Discriminator(nn.Module):
def __init__(self):
    super().__init__()
    layer_sizes = [1024, 512, 256]
    layers = []
    d_prev = 28 * 28
    for size in layer_sizes:
        layers = layers + [nn.Linear(d_prev, size), nn.LeakyReLU(0.2)]
        d_prev = size

    self.layers = nn.Sequential(*layers, nn.Linear(d_prev, 1))
    self.apply(weights_init)
def forward(self, x):
    return self.layers(x.view(x.shape[0], -1))
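A minimal sketch of converting the discriminator's logit into a probability (hypothetical usage, not part of the experiment):

# Hypothetical usage; `Discriminator` is defined above.
discriminator = Discriminator()
images = torch.randn(16, 1, 28, 28)        # stand-in for a batch of MNIST images
logits = discriminator(images)             # shape: [16, 1], unbounded logits
prob_real = torch.sigmoid(logits)          # probability that each input is real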
This extends the MNIST configurations (for the data loaders) and the training and validation loop configurations, to simplify our implementation.
class Configs(MNISTConfigs, TrainValidConfigs):
device: torch.device = DeviceConfigs()
dataset_transforms = 'mnist_gan_transforms'
epochs: int = 10

is_save_models = True
discriminator: nn.Module = 'mlp'
generator: nn.Module = 'mlp'
generator_optimizer: torch.optim.Adam
discriminator_optimizer: torch.optim.Adam
generator_loss: GeneratorLogitsLoss = 'original'
discriminator_loss: DiscriminatorLogitsLoss = 'original'
label_smoothing: float = 0.2
discriminator_k: int = 1
Initializations
def init(self):
self.state_modules = []

tracker.set_scalar("loss.generator.*", True)
tracker.set_scalar("loss.discriminator.*", True)
tracker.set_image("generated", True, 1 / 100)
Sample $z \sim p(z)$
def sample_z(self, batch_size: int):
return torch.randn(batch_size, 100, device=self.device)
Take a training step
def step(self, batch: Any, batch_idx: BatchIndex):
Set model states
self.generator.train(self.mode.is_train)
self.discriminator.train(self.mode.is_train)
Get MNIST images
data = batch[0].to(self.device)
Increment step in training mode
if self.mode.is_train:
    tracker.add_global_step(len(data))
Train the discriminator
with monit.section("discriminator"):
Get discriminator loss
    loss = self.calc_discriminator_loss(data)
Train
    if self.mode.is_train:
        self.discriminator_optimizer.zero_grad()
        loss.backward()
        if batch_idx.is_last:
            tracker.add('discriminator', self.discriminator)
        self.discriminator_optimizer.step()
Train the generator once every `discriminator_k` discriminator steps
if batch_idx.is_interval(self.discriminator_k):
    with monit.section("generator"):
        loss = self.calc_generator_loss(data.shape[0])
Train
        if self.mode.is_train:
            self.generator_optimizer.zero_grad()
            loss.backward()
            if batch_idx.is_last:
                tracker.add('generator', self.generator)
            self.generator_optimizer.step()

tracker.save()
Calculate discriminator loss
def calc_discriminator_loss(self, data):
latent = self.sample_z(data.shape[0])
logits_true = self.discriminator(data)
logits_false = self.discriminator(self.generator(latent).detach())
loss_true, loss_false = self.discriminator_loss(logits_true, logits_false)
loss = loss_true + loss_false
Log stuff
tracker.add("loss.discriminator.true.", loss_true)
tracker.add("loss.discriminator.false.", loss_false)
tracker.add("loss.discriminator.", loss)

return loss
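For intuition, the discriminator loss pushes logits on real images towards the "real" label and logits on generated images towards the "fake" label. Here is a minimal sketch of an equivalent binary cross-entropy computation, assuming one-sided label smoothing on the real targets; the exact smoothing inside `DiscriminatorLogitsLoss` may differ:

import torch.nn.functional as F

def discriminator_loss_sketch(logits_true, logits_false, smoothing=0.2):
    # Real images: target 1 - smoothing instead of 1 (label smoothing).
    loss_true = F.binary_cross_entropy_with_logits(
        logits_true, torch.full_like(logits_true, 1.0 - smoothing))
    # Generated images: target 0 (fake).
    loss_false = F.binary_cross_entropy_with_logits(
        logits_false, torch.zeros_like(logits_false))
    return loss_true, loss_false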
Calculate generator loss
def calc_generator_loss(self, batch_size: int):
latent = self.sample_z(batch_size)
generated_images = self.generator(latent)
logits = self.discriminator(generated_images)
loss = self.generator_loss(logits)
Log stuff
tracker.add('generated', generated_images[0:6])
tracker.add("loss.generator.", loss)

return loss
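The generator loss is the non-saturating variant: it trains the generator to make the discriminator classify generated images as real. A minimal sketch under the same label-smoothing assumption as above (the library's exact targets may differ):

import torch.nn.functional as F

def generator_loss_sketch(logits, smoothing=0.2):
    # Push the discriminator's logits on generated images towards "real".
    return F.binary_cross_entropy_with_logits(
        logits, torch.full_like(logits, 1.0 - smoothing))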
@option(Configs.dataset_transforms)
def mnist_gan_transforms():
    # Scale images to [-1, 1] to match the generator's tanh output range.
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])


@option(Configs.discriminator_optimizer)
def _discriminator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.discriminator.parameters()
    opt_conf.learning_rate = 2.5e-4
Setting the exponential decay rate for the first moment of the gradient, $\beta_1$, to 0.5 is important. The default of 0.9 fails.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf
@option(Configs.generator_optimizer)
def _generator_optimizer(c: Configs):
    opt_conf = OptimizerConfigs()
    opt_conf.optimizer = 'Adam'
    opt_conf.parameters = c.generator.parameters()
    opt_conf.learning_rate = 2.5e-4
Setting the exponential decay rate for the first moment of the gradient, $\beta_1$, to 0.5 is important. The default of 0.9 fails.
    opt_conf.betas = (0.5, 0.999)
    return opt_conf


calculate(Configs.generator, 'mlp', lambda c: Generator().to(c.device))
calculate(Configs.discriminator, 'mlp', lambda c: Discriminator().to(c.device))
calculate(Configs.generator_loss, 'original',
          lambda c: GeneratorLogitsLoss(c.label_smoothing).to(c.device))
calculate(Configs.discriminator_loss, 'original',
          lambda c: DiscriminatorLogitsLoss(c.label_smoothing).to(c.device))
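If you were setting the optimizers up without `OptimizerConfigs`, the equivalent plain PyTorch construction would look like this (a sketch with the same hyperparameters; `generator` and `discriminator` stand for the module instances):

# Hypothetical equivalent using torch.optim directly.
generator = Generator()
discriminator = Discriminator()
generator_optimizer = torch.optim.Adam(
    generator.parameters(), lr=2.5e-4, betas=(0.5, 0.999))
discriminator_optimizer = torch.optim.Adam(
    discriminator.parameters(), lr=2.5e-4, betas=(0.5, 0.999))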
def main():
    conf = Configs()
    experiment.create(name='mnist_gan', comment='test')
    experiment.configs(conf,
                       {'label_smoothing': 0.01})
    with experiment.start():
        conf.run()


if __name__ == '__main__':
    main()