Back to Annotated Deep Learning Paper Implementations

විසරණ සම්භාවිතා ආකෘති (ඩීඩීපීඑම්) පුහුණුව නිරූපණය කිරීම

docs/si/diffusion/ddpm/experiment.html

latest10.3 KB
Original Source

homediffusionddpm

View code on Github

#

විසරණ සම්භාවිතා ආකෘති (ඩීඩීපීඑම්) පුහුණුව නිරූපණය කිරීම

මෙය සෙලෙබා එච්කියු දත්ත කට්ටලය මත ඩීඩීපීඑම් පදනම් කරගත් ආකෘතියක් පුහුණු කරයි. fast.ai හි මෙම සාකච්ඡාවේදී බාගත කිරීමේ උපදෙස් ඔබට සොයාගත හැකිය. data/celebA ෆෝල්ඩරය තුළ පින්තූර සුරකින්න.

කඩදාසි ක ක්ෂය සමග ආදර්ශ ඝාතීය වෙනස්වන සාමාන්යය භාවිතා කර0.9999 ඇත. සරල බව සඳහා අපි මෙය මඟ හැර ඇත්තෙමු.

20fromtypingimportList2122importtorch23importtorch.utils.data24importtorchvision25fromPILimportImage2627fromlabmlimportlab,tracker,experiment,monit28fromlabml.configsimportBaseConfigs,option29fromlabml\_helpers.deviceimportDeviceConfigs30fromlabml\_nn.diffusion.ddpmimportDenoiseDiffusion31fromlabml\_nn.diffusion.ddpm.unetimportUNet

#

වින්යාසකිරීම්

34classConfigs(BaseConfigs):

#

ආකෘතියපුහුණු කිරීමේ උපකරණය. DeviceConfigs ලබා ගත හැකි CUDA උපාංගයක් අහුලනවා හෝ CPU කිරීමට පෙරනිමි.

41device:torch.device=DeviceConfigs()

#

සඳහාU-Net ආකෘතිය ϵθ​(xt​,t)

44eps\_model:UNet

#

ඩීඩීපීඑම් ඇල්ගොරිතම

46diffusion:DenoiseDiffusion

#

රූපයේනාලිකා ගණන. 3 RGB සඳහා.

49image\_channels:int=3

#

රූපප්රමාණය

51image\_size:int=32

#

ආරම්භකවිශේෂාංග සිතියමේ නාලිකා ගණන

53n\_channels:int=64

#

එක්එක් විභේදනයේ නාලිකා අංක ලැයිස්තුව. නාලිකා ගණන වේ channel_multipliers[i] * n_channels

56channel\_multipliers:List[int]=[1,2,2,4]

#

එක්එක් යෝජනාවේදී අවධානය භාවිතා කළ යුතුද යන්න පෙන්වන බූලියන් ලැයිස්තුව

58is\_attention:List[int]=[False,False,False,True]

#

කාලපියවර ගණන T

61n\_steps:int=1\_000

#

කණ්ඩායම්ප්රමාණය

63batch\_size:int=64

#

උත්පාදනයකිරීමට සාම්පල ගණන

65n\_samples:int=16

#

ඉගෙනුම්අනුපාතය

67learning\_rate:float=2e-5

#

පුහුණුඑපොච් ගණන

70epochs:int=1\_000

#

දත්තකට්ටලය

73dataset:torch.utils.data.Dataset

#

දත්තකාරකය

75data\_loader:torch.utils.data.DataLoader

#

ආදම්ප්රශස්තකරණය

78optimizer:torch.optim.Adam

#

80definit(self):

#

ϵθ​(xt​,t) ආකෘතිය සාදන්න

82self.eps\_model=UNet(83image\_channels=self.image\_channels,84n\_channels=self.n\_channels,85ch\_mults=self.channel\_multipliers,86is\_attn=self.is\_attention,87).to(self.device)

#

DDPM පන්ති නිර්මාණය

90self.diffusion=DenoiseDiffusion(91eps\_model=self.eps\_model,92n\_steps=self.n\_steps,93device=self.device,94)

#

දත්තකාරකය සාදන්න

97self.data\_loader=torch.utils.data.DataLoader(self.dataset,self.batch\_size,shuffle=True,pin\_memory=True)

#

ප්රශස්තකරණයසාදන්න

99self.optimizer=torch.optim.Adam(self.eps\_model.parameters(),lr=self.learning\_rate)

#

රූපලොග් වීම

102tracker.set\_image("sample",True)

#

නියැදිරූප

104defsample(self):

#

108withtorch.no\_grad():

#

xT​∼p(xT​)=N(xT​;0,I)

110x=torch.randn([self.n\_samples,self.image\_channels,self.image\_size,self.image\_size],111device=self.device)

#

T පියවර සඳහා ශබ්දය ඉවත් කරන්න

114fort\_inmonit.iterate('Sample',self.n\_steps):

#

t

116t=self.n\_steps-t\_-1

#

වෙතින්නියැදිය pθ​(xt−1​∣xt​)

118x=self.diffusion.p\_sample(x,x.new\_full((self.n\_samples,),t,dtype=torch.long))

#

ලොග්සාම්පල

121tracker.save('sample',x)

#

දුම්රිය

123deftrain(self):

#

දත්තසමුදාය හරහා නැවත කරන්න

129fordatainmonit.iterate('Train',self.data\_loader):

#

ගෝලීයපියවර වැඩි කිරීම

131tracker.add\_global\_step()

#

උපාංගයවෙත දත්ත ගෙනයන්න

133data=data.to(self.device)

#

අනුක්රමිකශුන්ය කරන්න

136self.optimizer.zero\_grad()

#

අලාභයගණනය කරන්න

138loss=self.diffusion.loss(data)

#

අනුක්රමිකගණනය

140loss.backward()

#

ප්රශස්තිකරණපියවරක් ගන්න

142self.optimizer.step()

#

අලාභයලුහුබඳින්න

144tracker.save('loss',loss)

#

පුහුණුලූපය

146defrun(self):

#

150for\_inmonit.loop(self.epochs):

#

ආකෘතියපුහුණු කරන්න

152self.train()

#

පින්තූරකිහිපයක් සාම්පල කරන්න

154self.sample()

#

කොන්සෝලයේනව රේඛාවක්

156tracker.new\_line()

#

ආකෘතියසුරකින්න

158experiment.save\_checkpoint()

#

සෙලෙබාමූලස්ථානය දත්ත කට්ටලය

161classCelebADataset(torch.utils.data.Dataset):

#

166def\_\_init\_\_(self,image\_size:int):167super().\_\_init\_\_()

#

සෙලෙබාපින්තූර ෆෝල්ඩරය

170folder=lab.get\_data\_path()/'celebA'

#

ගොනුලැයිස්තුව

172self.\_files=[pforpinfolder.glob(f'\*\*/\*.jpg')]

#

රූපයවෙනස් කර ටෙන්සර් බවට පරිවර්තනය කිරීම සඳහා පරිවර්තනයන්

175self.\_transform=torchvision.transforms.Compose([176torchvision.transforms.Resize(image\_size),177torchvision.transforms.ToTensor(),178])

#

දත්තසමුදාය ප්රමාණය

180def\_\_len\_\_(self):

#

184returnlen(self.\_files)

#

රූපයක්ලබා ගන්න

186def\_\_getitem\_\_(self,index:int):

#

190img=Image.open(self.\_files[index])191returnself.\_transform(img)

#

සෙලෙබාදත්ත කට්ටලය සාදන්න

194@option(Configs.dataset,'CelebA')195defceleb\_dataset(c:Configs):

#

199returnCelebADataset(c.image\_size)

#

MNISTදත්ත කට්ටලය

202classMNISTDataset(torchvision.datasets.MNIST):

#

207def\_\_init\_\_(self,image\_size):208transform=torchvision.transforms.Compose([209torchvision.transforms.Resize(image\_size),210torchvision.transforms.ToTensor(),211])212213super().\_\_init\_\_(str(lab.get\_data\_path()),train=True,download=True,transform=transform)

#

215def\_\_getitem\_\_(self,item):216returnsuper().\_\_getitem\_\_(item)[0]

#

MNISTදත්ත සමුදාය සාදන්න

219@option(Configs.dataset,'MNIST')220defmnist\_dataset(c:Configs):

#

224returnMNISTDataset(c.image\_size)

#

227defmain():

#

අත්හදාබැලීම සාදන්න

229experiment.create(name='diffuse',writers={'screen','labml'})

#

වින්යාසයන්සාදන්න

232configs=Configs()

#

වින්යාසයන්සකසන්න. ශබ්දකෝෂයේ අගයන් සම්මත කිරීමෙන් ඔබට පෙරනිමි අභිබවා යා හැකිය.

235experiment.configs(configs,{236'dataset':'CelebA',# 'MNIST'237'image\_channels':3,# 1,238'epochs':100,# 5,239})

#

ආරම්භකරන්න

242configs.init()

#

ඉතිරිකිරීම සහ පැටවීම සඳහා ආකෘති සකසන්න

245experiment.add\_pytorch\_models({'eps\_model':configs.eps\_model})

#

පුහුණුලූපය ආරම්භ කර ක්රියාත්මක කරන්න

248withexperiment.start():249configs.run()

#

253if\_\_name\_\_=='\_\_main\_\_':254main()

Trending Research Paperslabml.ai