ColorizeTrainingVideo.ipynb
#NOTE: This must be the first call in order to work properly!
from deoldify import device
from deoldify.device_id import DeviceId
#choices: CPU, GPU0...GPU7
device.set(device=DeviceId.GPU0)
import os
import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks.tensorboard import *
from fastai.vision.gan import *
from deoldify.generators import *
from deoldify.critics import *
from deoldify.dataset import *
from deoldify.loss import *
from deoldify.save import *
from deoldify.augs import noisify
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageFile
# --- Dataset locations ---------------------------------------------------------
# Original (color) images live under the ImageNet CLS-LOC directory; the
# black-and-white copies made by create_training_images go into a 'bandw'
# subfolder of that same directory.
path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')
path_hr = path
path_lr = path/'bandw'

# --- Run / checkpoint naming ---------------------------------------------------
proj_id = 'VideoModel'
gen_name = f'{proj_id}_gen'
pre_gen_name = f'{gen_name}_0'
crit_name = f'{proj_id}_crit'
name_gen = f'{proj_id}_image_gen'
# Folder that save_gen_images fills with generator outputs; it is also one of
# the two class folders passed to get_crit_data.
path_gen = path/name_gen
TENSORBOARD_PATH = Path(f'data/tensorboard/{proj_id}')

# --- Model / training knobs ----------------------------------------------------
nf_factor = 2
# Extra augmentation applied to the generator data (see get_data).
xtra_tfms = [noisify(p=0.8)]
# One-cycle warm-up fraction (fastai semantics); effectively no warm-up.
pct_start = 1e-8
def get_data(bs:int, sz:int, keep_pct:float):
    """Build the generator's DataBunch via ``get_colorize_data``.

    Inputs come from the black-and-white copies (``crappy_path=path_lr``) and
    targets from the original images (``good_path=path_hr``). The module-level
    ``xtra_tfms`` (noise augmentation) is always applied. No fixed random seed is used.

    Args:
        bs: batch size.
        sz: image size passed to ``get_colorize_data``.
        keep_pct: forwarded unchanged; presumably the fraction of images kept
            (confirm in ``deoldify.dataset.get_colorize_data``).
    """
    colorize_kwargs = dict(
        sz=sz,
        bs=bs,
        crappy_path=path_lr,
        good_path=path_hr,
        random_seed=None,
        keep_pct=keep_pct,
        xtra_tfms=xtra_tfms,
    )
    return get_colorize_data(**colorize_kwargs)
def get_crit_data(classes, bs, sz):
    """Build the critic's classification DataBunch from class folders under ``path``.

    Images are collected recursively from the folders named in ``classes`` and
    labelled by folder name. 10% of them are held out for validation using a fixed
    seed (42). Training uses ``get_transforms(max_zoom=2.)`` at size ``sz``, and
    the data is normalized with ImageNet statistics.

    Args:
        classes: folder names under ``path`` that act as class labels.
        bs: batch size.
        sz: image size.
    """
    all_items = ImageList.from_folder(path, include=classes, recurse=True)
    split_items = all_items.split_by_rand_pct(0.1, seed=42)
    labelled = split_items.label_from_folder(classes=classes)
    resized = labelled.transform(get_transforms(max_zoom=2.), size=sz)
    bunch = resized.databunch(bs=bs)
    return bunch.normalize(imagenet_stats)
def create_training_images(fn,i):
    """Write a black-and-white copy of image ``fn`` under ``path_lr``.

    The copy keeps ``fn``'s location relative to ``path_hr`` and creates any
    missing parent directories. The image is converted to 'LA' (luminance + alpha)
    and then back to 'RGB'. The output therefore has three channels but no color.

    Args:
        fn: path of a source image inside ``path_hr``.
        i: second argument supplied by ``parallel`` (presumably the item
            index); unused here but required by that calling convention.
    """
    dest = path_lr/fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Close the source file deterministically instead of relying on Pillow's
    # implicit close / garbage collection. This runs over the whole dataset
    # in parallel workers.
    with PIL.Image.open(fn) as src:
        img = src.convert('LA').convert('RGB')
    img.save(dest)
def save_preds(dl):
    """Run the global ``learn_gen`` over ``dl`` and save every prediction to ``path_gen``.

    Each prediction is saved under the file name of the dataset item at the
    same running position (``dl.dataset.items``). This assumes ``dl`` yields
    items in dataset order, i.e. it is not shuffled.
    """
    items = dl.dataset.items
    n_saved = 0
    for batch in dl:
        batch_preds = learn_gen.pred_batch(batch=batch, reconstruct=True)
        for pred in batch_preds:
            pred.save(path_gen/items[n_saved].name)
            n_saved += 1
def save_gen_images(keep_pct:float=0.085):
    """Regenerate the folder of generator outputs used as a critic class.

    Deletes ``path_gen`` if it exists and recreates it. It then builds generator
    data with ``get_data`` and writes the current ``learn_gen`` predictions for
    its ``fix_dl`` into ``path_gen`` via ``save_preds``.

    Reads the module-level ``bs`` and ``sz`` (and, through ``save_preds``,
    ``learn_gen``), so set those before calling.

    Args:
        keep_pct: forwarded to ``get_data``; defaults to 0.085, the value
            previously hard-coded here.

    Returns:
        The first file in ``path_gen`` opened with PIL, as a preview a notebook
        cell can display. Returns ``None`` when no predictions were written.
    """
    if path_gen.exists():
        shutil.rmtree(path_gen)
    path_gen.mkdir(parents=True, exist_ok=True)
    data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)
    save_preds(data_gen.fix_dl)
    # Return the preview rather than evaluating and discarding it.
    saved = path_gen.ls()
    return PIL.Image.open(saved[0]) if saved else None
The next cell only runs if the black-and-white image directory (`path_lr`) hasn't been created yet.
# Create the black-and-white training inputs once; skipped when path_lr exists.
if not path_lr.exists():
    # NOTE(review): path_hr is the parent of path_lr and path_gen. If this is
    # rerun after deleting only path_lr, any existing path_gen images will also
    # be collected here (assuming from_folder recurses) — confirm that is intended.
    il = ImageList.from_folder(path_hr)
    # Convert every collected image with create_training_images via fastai's parallel().
    parallel(create_training_images, il.items)
# --- Generator pretraining ----------------------------------------------------
# Batch size, image size and data subset for this stage. keep_pct is forwarded
# to get_colorize_data; presumably the fraction of images kept (confirm there).
bs=8
sz=192
keep_pct=0.25
# data_gen is also reused by the later cells that rebuild the generator.
data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)
learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))
# Start from the existing pretrained generator weights; optimizer state is not loaded.
learn_gen = learn_gen.load(pre_gen_name, with_opt=False)
learn_gen.unfreeze()
# One epoch; the slice spreads LRs across layer groups (fastai discriminative LRs).
learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))
# NOTE: saves back to pre_gen_name, overwriting the checkpoint just loaded.
learn_gen.save(pre_gen_name)
Best results so far have come from running the cells below only once; running them repeatedly introduces glitches that are visible in the video.
# Checkpoint bookkeeping: this round loads "<name>_<old>" and writes
# "<name>_<old + 1>" for both the generator and the critic.
old_checkpoint_num = 0
checkpoint_num = old_checkpoint_num + 1
gen_old_checkpoint_name = f'{gen_name}_{old_checkpoint_num}'
gen_new_checkpoint_name = f'{gen_name}_{checkpoint_num}'
crit_old_checkpoint_name = f'{crit_name}_{old_checkpoint_num}'
crit_new_checkpoint_name = f'{crit_name}_{checkpoint_num}'
# --- Produce generator samples for critic training ---------------------------
# save_gen_images() reads the global bs and sz, so set them first.
bs=8
sz=192
# Generator from the previous round (weights only, no optimizer state).
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)
# Wipes and refills path_gen with this generator's predictions.
save_gen_images()
# --- Critic training ----------------------------------------------------------
bs=16
sz=192
# Drop the generator and collect garbage before building the critic
# (presumably to release GPU memory).
learn_gen=None
gc.collect()
# Two class folders: generator outputs (name_gen) and 'test'
# (presumably real color images — confirm).
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)
# Visual sanity check of the critic's training batches.
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)
# NOTE(review): this learner is bound to `learn_critic` (not `learn_crit`);
# any later memory cleanup must clear this name.
learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)
learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))
# 4 epochs of one-cycle training with max LR 1e-4.
learn_critic.fit_one_cycle(4, 1e-4)
learn_critic.save(crit_new_checkpoint_name)
# Release every learner before building the GAN so their models can be
# garbage-collected. The critic trained in the critic-training cell is bound to
# `learn_critic`, so that name must be cleared. `learn_crit` is cleared too,
# in case it holds a learner from an earlier run of the GAN cells.
learn_crit=None
learn_critic=None
learn_gen=None
gc.collect()
# --- GAN fine-tuning setup ----------------------------------------------------
# lr is used by the fit cell below; bs and sz are also reused there by get_data.
lr=5e-6
sz=192
bs=5
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)
# Critic saved by this round's critic training, paired with the previous
# round's generator; both load weights only.
learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)
# Switches between generator and critic training; critic_thresh=0.65 is
# presumably the critic-loss threshold that triggers a switch (confirm in fastai docs).
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
# weights_gen: in fastai v1 this is (adversarial-loss weight, generator-loss
# weight) — confirm against the installed version. Adam with beta1=0, weight decay 1e-3.
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)
# Presumably scales the LR by 5x while the critic is training (fastai GANDiscriminativeLR).
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100, stats_iters=10, loss_iters=1))
# Presumably saves learn_gen as gen_new_checkpoint_name every 100 iterations, so
# the last checkpoint before glitches appear can be recovered.
learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))
Find the checkpoint just before glitches start to appear. So far this has been at roughly 1.4% of the way through the data with a learning rate of 1e-5, and at roughly 2.2% with a learning rate of 5e-6.
# Train on a small subset (keep_pct=0.03) so that the periodic checkpoints can be
# inspected for the point where glitches begin.
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)
# fastai: freeze all but the generator's last layer group.
learn_gen.freeze_to(-1)
# One epoch at the learning rate set in the setup cell.
learn.fit(1,lr)