ColorizeTrainingStableLargeBatch.ipynb
import os
os.environ['CUDA_VISIBLE_DEVICES']='0'
import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks.tensorboard import *
from fastai.vision.gan import *
from deoldify.generators import *
from deoldify.critics import *
from deoldify.dataset import *
from deoldify.loss import *
from deoldify.save import *
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageFile
Enabling PyTorch Large Model Support (LMS) lets the model fit on a GPU with less memory (e.g. a GTX 1070 with 8 GB).
Large Model Support (LMS) is a feature provided in IBM Watson Machine Learning Community Edition (WML-CE) PyTorch V1.1.0 that allows the successful training of deep learning models that would otherwise exhaust GPU memory and abort with “out-of-memory” errors. LMS manages this oversubscription of GPU memory by temporarily swapping tensors to host memory when they are not needed. One or more elements of a deep learning model can lead to GPU memory exhaustion.
This requires IBM WML-CE (available here: https://www.ibm.com/support/knowledgecenter/en/SS5SF7_1.6.1/welcome/welcome.html).
Further Reading on PyTorch with Large Model Support: https://www.ibm.com/support/knowledgecenter/en/SS5SF7_1.6.1/navigation/wmlce_getstarted_pytorch.html
import shutil
# GPU memory limit, in GB, for PyTorch LMS. Once usage exceeds this limit, LMS
# swaps tensors that are not currently needed out to host memory.
max_gpu_mem = 7


def gb_to_bytes(gb):
    """Convert a size in GB (1 GB = 1024**3 bytes) to a number of bytes."""
    return gb*1024*1024*1024
# Enable PyTorch Large Model Support (LMS). The torch.cuda.*_lms functions are
# provided by IBM WML-CE's PyTorch build, which this notebook requires; fail
# with a clear message instead of an opaque AttributeError when it is missing.
if not hasattr(torch.cuda, 'set_enabled_lms'):
    raise RuntimeError(
        'PyTorch LMS is unavailable: this notebook requires the IBM WML-CE '
        'build of PyTorch (torch.cuda.set_enabled_lms not found).')
torch.cuda.set_enabled_lms(True)
# Set the LMS limit: beyond max_gpu_mem GB of GPU memory, tensors are swapped
# to host memory.
torch.cuda.set_limit_lms(gb_to_bytes(max_gpu_mem))
# Check that LMS is enabled and that the limit has been set. Print explicitly,
# since a bare expression is only displayed when it is the last line of a cell.
print('LMS enabled:', torch.cuda.get_enabled_lms())
print('LMS limit (bytes):', torch.cuda.get_limit_lms())
# Root of the color training images (ImageNet). These serve as the targets
# ("good" images) and as the source for the black-and-white copies.
path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')
path_hr = path
# Root for the black-and-white (input) copies of the training images
path_bandw = Path('/training/DeOldify')
path_lr = path_bandw/'bandw'
# Name of Model
proj_id = 'StableModel'
# Checkpoint-name prefix for the generator; pretraining saves to '<gen_name>_0'
gen_name = f'{proj_id}_gen'
pre_gen_name = f'{gen_name}_0'
# Checkpoint-name prefix for the critic
crit_name = f'{proj_id}_crit'
# Folder of generator outputs, used as one of the two critic classes
# (alongside 'test'). It lives under `path`, the color training-data root,
# NOT under the black-and-white folder: get_crit_data() builds its ImageList
# from `path`, so the folder must sit there to be found.
name_gen = f'{proj_id}_image_gen'
path_gen = path/name_gen
# Path to tensorboard data
TENSORBOARD_PATH = Path('data/tensorboard')/proj_id
# Passed as nf_factor to gen_learner_wide (presumably a filter-width multiplier)
nf_factor = 2
# pct_start for the later one-cycle fits. A near-zero value presumably means
# almost no LR warm-up, so the schedule is nearly all annealing; confirm
# against fastai's fit_one_cycle.
pct_start = 1e-8
# Number of workers for DataLoader
num_works = 2
def get_data(bs: int, sz: int, keep_pct: float):
    """Build the colorization data: black-and-white inputs paired with color targets.

    Args:
        bs: batch size.
        sz: image size.
        keep_pct: passed through to get_colorize_data (presumably the fraction
            of the dataset to keep).
    """
    # Arguments that are fixed for every call in this notebook.
    fixed_kwargs = dict(
        crappy_path=path_lr,
        good_path=path_hr,
        random_seed=None,
        num_workers=num_works,
    )
    return get_colorize_data(sz=sz, bs=bs, keep_pct=keep_pct, **fixed_kwargs)
def get_crit_data(classes, bs, sz):
    """Build the critic's data: images under `path`, labelled by folder name.

    Only the sub-folders named in `classes` are included. 10% of the images
    are held out for validation (seed 42). Images are augmented, resized to
    `sz`, batched by `bs` and normalized with ImageNet statistics.
    """
    items = ImageList.from_folder(path, include=classes, recurse=True)
    split_items = items.split_by_rand_pct(0.1, seed=42)
    labelled = split_items.label_from_folder(classes=classes)
    augmented = labelled.transform(get_transforms(max_zoom=2.), size=sz)
    return augmented.databunch(bs=bs).normalize(imagenet_stats)
def create_training_images(fn, i):
    """Write a black-and-white copy of color image `fn` under path_lr.

    The copy mirrors fn's location relative to path_hr, and any missing
    parent directories are created. `i` is unused; it is presumably accepted
    because fastai's `parallel` calls the function with (item, index).
    Confirm this.
    """
    dest = path_lr/fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Open the source in a context manager so its file handle is closed
    # deterministically. This runs once per image over the whole dataset.
    with PIL.Image.open(fn) as src:
        # 'LA' drops the color; converting back to 'RGB' keeps the
        # 3-channel layout the model expects.
        img = src.convert('LA').convert('RGB')
    img.save(dest)
def save_preds(dl):
    """Run the global generator `learn_gen` over `dl` and save each prediction to path_gen.

    Output file names come, in order, from `dl.dataset.items`, so `dl` must
    yield its items in dataset order (presumably why save_gen_images passes
    fix_dl).
    """
    item_paths = dl.dataset.items
    written = 0
    for batch in dl:
        batch_preds = learn_gen.pred_batch(batch=batch, reconstruct=True)
        for pred_img in batch_preds:
            pred_img.save(path_gen/item_paths[written].name)
            written += 1
def save_gen_images():
    """Regenerate path_gen with generator outputs and return one as a preview.

    Any existing path_gen is deleted and recreated. The generator's predictions
    are then written into it for a keep_pct=0.085 subset of the data, built
    with the global `bs` and `sz`.

    Returns:
        An in-memory copy of the first image listed in path_gen (which a
        notebook cell displays), or None if no image was written.
    """
    if path_gen.exists():
        shutil.rmtree(path_gen)
    path_gen.mkdir(exist_ok=True)
    data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)
    save_preds(data_gen.fix_dl)
    saved = path_gen.ls()
    if not saved:
        return None
    # copy() loads the pixels, so the file can be closed before returning.
    with PIL.Image.open(saved[0]) as preview:
        return preview.copy()
Create the black-and-white training images. This only runs if the black-and-white image directory does not already exist.
# Convert every color training image to a black-and-white copy, but only when
# the black-and-white directory has not been created yet.
if not path_lr.exists():
    il = ImageList.from_folder(path_hr)
    parallel(create_training_images, il.items)
Most of the NoGAN training happens here, during generator pretraining. The goal is to take the generator as far as possible with conventional training, because that is much easier to control and produces glitch-free results more reliably than GAN training.
# Generator pretraining, stage 1: 64px on the full dataset.
# bs can be increased if using PyTorch LMS, though training could be slower.
bs, sz, keep_pct = 88, 64, 1.0
data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)
learn_gen.callback_fns.append(
    partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre')
)
# First fit before unfreezing (the learner is presumably partly frozen by
# default, given the unfreeze() below).
learn_gen.fit_one_cycle(1, pct_start=0.8, max_lr=slice(1e-3))
learn_gen.save(pre_gen_name)
# Then train the whole network with discriminative learning rates.
learn_gen.unfreeze()
learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(3e-7, 3e-4))
learn_gen.save(pre_gen_name)
# Generator pretraining, stage 2: 128px on the full dataset.
# bs can be increased if using PyTorch LMS, though training could be slower.
bs, sz, keep_pct = 40, 128, 1.0
learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)
learn_gen.unfreeze()
learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(1e-7, 1e-4))
learn_gen.save(pre_gen_name)
# Generator pretraining, stage 3: 192px with keep_pct=0.50.
# bs can be increased if using PyTorch LMS, though training could be slower.
bs, sz, keep_pct = 16, 192, 0.50
learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)
learn_gen.unfreeze()
learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8, 5e-5))
learn_gen.save(pre_gen_name)
# Generator pretraining, stage 4: 256px with keep_pct=0.50.
# bs can be increased if using PyTorch LMS, though training could be slower.
bs, sz, keep_pct = 8, 256, 0.50
learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)
learn_gen.unfreeze()
learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8, 5e-5))
learn_gen.save(pre_gen_name)
The best results so far have come from repeating the cycle below several times (roughly 5–8) until returns diminish, i.e. image quality stops improving. Each time you repeat the cycle, increment `old_checkpoint_num` by 1 so that new checkpoints don't overwrite the old ones.
# Increment old_checkpoint_num by 1 on every repetition of the cycle below.
old_checkpoint_num = 0
checkpoint_num = old_checkpoint_num + 1
# Checkpoint names look like '<prefix>_<number>'.
gen_old_checkpoint_name = f'{gen_name}_{old_checkpoint_num}'
gen_new_checkpoint_name = f'{gen_name}_{checkpoint_num}'
crit_old_checkpoint_name = f'{crit_name}_{old_checkpoint_num}'
crit_new_checkpoint_name = f'{crit_name}_{checkpoint_num}'
# Reload the previous generator checkpoint and regenerate its output images
# (at bs=8, sz=256) for the critic to train on.
bs, sz = 8, 256
learn_gen = gen_learner_wide(
    data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor
).load(gen_old_checkpoint_name, with_opt=False)
save_gen_images()
# Full critic pretraining is only needed on the first cycle; later cycles
# just fine-tune the saved critic.
if old_checkpoint_num == 0:
    bs, sz = 64, 128
    # Release the generator learner before building the critic.
    learn_gen = None
    gc.collect()
    data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)
    data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)
    learn_critic = colorize_crit_learner(data=data_crit, nf=256)
    learn_critic.callback_fns.append(
        partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre')
    )
    learn_critic.fit_one_cycle(6, 1e-3)
    learn_critic.save(crit_old_checkpoint_name)
# Fine-tune the critic at 256px, starting from the previous critic checkpoint,
# and save the result as the new critic checkpoint.
bs, sz = 8, 256
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)
learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(
    crit_old_checkpoint_name, with_opt=False
)
learn_critic.callback_fns.append(
    partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre')
)
learn_critic.fit_one_cycle(4, 1e-4)
learn_critic.save(crit_new_checkpoint_name)
# Drop references to the learners so gc can reclaim their memory before the
# GAN learner is built. The critic trained above is bound to `learn_critic`
# (not `learn_crit`), so that is the name that must be cleared. `learn_crit`
# and `learn` are cleared too: on a repeated cycle they still hold the
# previous GAN run's critic/generator.
learn_critic = None
learn_crit = None
learn = None
learn_gen = None
gc.collect()
# GAN stage: small batches at 256px, starting from the new critic checkpoint
# and the previous generator checkpoint.
lr = 2e-5
sz = 256
bs = 5
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)
learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(
    crit_new_checkpoint_name, with_opt=False
)
learn_gen = gen_learner_wide(
    data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor
).load(gen_old_checkpoint_name, with_opt=False)
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(
    learn_gen,
    learn_crit,
    weights_gen=(1.0, 1.5),
    show_img=False,
    switcher=switcher,
    opt_func=partial(optim.Adam, betas=(0., 0.9)),
    wd=1e-3,
)
# Callbacks: discriminative LR, TensorBoard logging, and a generator checkpoint
# saved to gen_new_checkpoint_name every 100 iterations.
learn.callback_fns.extend([
    partial(GANDiscriminativeLR, mult_lr=5.),
    partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100),
    partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100),
])
Find the checkpoint saved just before glitches start to appear. This approach is still experimental, so you may need to experiment with how far to train here, i.e. with the value of `keep_pct`.
# Short GAN run on a small slice of the data (keep_pct=0.03). The
# GANSaveCallback registered on `learn` writes generator checkpoints as it goes.
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)
# Freeze the generator up to its last layer group.
learn_gen.freeze_to(-1)
learn.fit(1, lr)