Back to Annotated Deep Learning Paper Implementations

U-ශුද්ධපුහුණු

docs/si/unet/experiment.html

latest7.9 KB
Original Source

homeunet

View code on Github

#

U-ශුද්ධපුහුණු

මෙය කාර්වානා දත්ත කට්ටලයේ යූ-නෙට් ආකෘතියක් පුහුණු කරයි. ඔබට බාගත කිරීමේ උපදෙස් සොයාගත හැකිය Kaggle.

carvana/train ෆෝල්ඩරය තුළ පුහුණු පින්තූර සහ carvana/train_masks ෆෝල්ඩරයේ වෙස් මුහුණු සුරකින්න.

සරලබව සඳහා, අපි පුහුණුවක් සහ වලංගු භේදයක් නොකරමු.

19importnumpyasnp20importtorch21importtorch.utils.data22importtorchvision.transforms.functional23fromtorchimportnn2425fromlabmlimportlab,tracker,experiment,monit26fromlabml.configsimportBaseConfigs27fromlabml\_helpers.deviceimportDeviceConfigs28fromlabml\_nn.unet.carvanaimportCarvanaDataset29fromlabml\_nn.unetimportUNet

#

වින්යාසකිරීම්

32classConfigs(BaseConfigs):

#

ආකෘතියපුහුණු කිරීමේ උපකරණය. DeviceConfigs ලබා ගත හැකි CUDA උපාංගයක් අහුලනවා හෝ CPU කිරීමට පෙරනිමි.

39device:torch.device=DeviceConfigs()

#

යූ-නෙට් ආකෘතිය

42model:UNet

#

රූපයේනාලිකා ගණන. 3 RGB සඳහා.

45image\_channels:int=3

#

නිමැවුම්ආවරණයේ නාලිකා ගණන. 1 ද්විමය ආවරණ සඳහා.

47mask\_channels:int=1

#

කණ්ඩායම්ප්රමාණය

50batch\_size:int=1

#

ඉගෙනුම්අනුපාතය

52learning\_rate:float=2.5e-4

#

පුහුණුඑපොච් ගණන

55epochs:int=4

#

දත්තකට්ටලය

58dataset:CarvanaDataset

#

දත්තකාරකය

60data\_loader:torch.utils.data.DataLoader

#

පාඩුශ්රිතය

63loss\_func=nn.BCELoss()

#

ද්විමයවර්ගීකරණය සඳහා සිග්මෝයිඩ් ශ්රිතය

65sigmoid=nn.Sigmoid()

#

ආදම්ප්රශස්තකරණය

68optimizer:torch.optim.Adam

#

70definit(self):

#

කාර්වානා දත්ත කට්ටලය ආරම්භ කරන්න

72self.dataset=CarvanaDataset(lab.get\_data\_path()/'carvana'/'train',73lab.get\_data\_path()/'carvana'/'train\_masks')

#

ආකෘතියආරම්භ කරන්න

75self.model=UNet(self.image\_channels,self.mask\_channels).to(self.device)

#

දත්තකාරකය සාදන්න

78self.data\_loader=torch.utils.data.DataLoader(self.dataset,self.batch\_size,79shuffle=True,pin\_memory=True)

#

ප්රශස්තකරණයසාදන්න

81self.optimizer=torch.optim.Adam(self.model.parameters(),lr=self.learning\_rate)

#

රූපලොග් වීම

84tracker.set\_image("sample",True)

#

නියැදිරූප

[email protected]\_grad()87defsample(self,idx=-1):

#

අහඹුනියැදියක් ලබා ගන්න

93x,\_=self.dataset[np.random.randint(len(self.dataset))]

#

උපාංගයවෙත දත්ත ගෙනයන්න

95x=x.to(self.device)

#

පුරෝකථනයකළ වෙස් මුහුණ ලබා ගන්න

98mask=self.sigmoid(self.model(x[None,:]))

#

වෙස්මුහුණෙහි ප්රමාණයට රූපය වගා කරන්න

100x=torchvision.transforms.functional.center\_crop(x,[mask.shape[2],mask.shape[3]])

#

ලොග්සාම්පල

102tracker.save('sample',x\*mask)

#

කepoch සඳහා දුම්රිය

104deftrain(self):

#

දත්තකට්ටලය හරහා නැවත කරන්න. එක් mix 50 යුගයකට නියැදි වේලාවන් භාවිතා කරන්න.

112for\_,(image,mask)inmonit.mix(('Train',self.data\_loader),(self.sample,list(range(50)))):

#

ගෝලීයපියවර වැඩි කිරීම

114tracker.add\_global\_step()

#

උපාංගයවෙත දත්ත ගෙනයන්න

116image,mask=image.to(self.device),mask.to(self.device)

#

අනුක්රමිකශුන්ය කරන්න

119self.optimizer.zero\_grad()

#

පුරෝකථනයකරන ලද වෙස්මුහුණු පිවිසුම් ලබා ගන්න

121logits=self.model(image)

#

ඉලක්කගතවෙස්මුහුණ පිවිසුම් ප්රමාණයට වගා කරන්න. යූ-නෙට් හි සංවලිත ස්ථර වල අපි පෑඩින් භාවිතා නොකරන්නේ නම් පිවිසුම් ප්රමාණය කුඩා වේ.

124mask=torchvision.transforms.functional.center\_crop(mask,[logits.shape[2],logits.shape[3]])

#

අලාභයගණනය කරන්න

126loss=self.loss\_func(self.sigmoid(logits),mask)

#

අනුක්රමිකගණනය

128loss.backward()

#

ප්රශස්තිකරණපියවරක් ගන්න

130self.optimizer.step()

#

අලාභයලුහුබඳින්න

132tracker.save('loss',loss)

#

පුහුණුලූපය

134defrun(self):

#

138for\_inmonit.loop(self.epochs):

#

ආකෘතියපුහුණු කරන්න

140self.train()

#

කොන්සෝලයේනව රේඛාවක්

142tracker.new\_line()

#

ආකෘතියසුරකින්න

144experiment.save\_checkpoint()

#

147defmain():

#

අත්හදාබැලීම සාදන්න

149experiment.create(name='unet')

#

වින්යාසයන්සාදන්න

152configs=Configs()

#

වින්යාසයන්සකසන්න. ශබ්දකෝෂයේ අගයන් සම්මත කිරීමෙන් ඔබට පෙරනිමි අභිබවා යා හැකිය.

155experiment.configs(configs,{})

#

ආරම්භකරන්න

158configs.init()

#

ඉතිරිකිරීම සහ පැටවීම සඳහා ආකෘති සකසන්න

161experiment.add\_pytorch\_models({'model':configs.model})

#

පුහුණුලූපය ආරම්භ කර ක්රියාත්මක කරන්න

164withexperiment.start():165configs.run()

#

169if\_\_name\_\_=='\_\_main\_\_':170main()

Trending Research Paperslabml.ai