Back to Deoldify

ColorizeTrainingVideo

ColorizeTrainingVideo.ipynb

latest6.1 KB
Original Source

Video Model Training

NOTES:

  • It's assumed that there's a pretrained generator from the ColorizeTrainingStable notebook available at the specified path.
  • This is "NoGAN" based training, described in the DeOldify readme.
python
#NOTE:  This must be the first call in order to work properly!
from deoldify import device
from deoldify.device_id import DeviceId
#choices:  CPU, GPU0...GPU7
device.set(device=DeviceId.GPU0)
python
import os
import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks.tensorboard import *
from fastai.vision.gan import *
from deoldify.generators import *
from deoldify.critics import *
from deoldify.dataset import *
from deoldify.loss import *
from deoldify.save import *
from deoldify.augs import noisify 
from PIL import Image, ImageDraw, ImageFont
from PIL import ImageFile

Setup

python
# Dataset root. Colour originals are read from path_hr (the root itself);
# their black-and-white copies are written to / read from path_lr.
path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')
path_hr = path
path_lr = path/'bandw'

# Checkpoint name stems. pre_gen_name is the generator checkpoint that is
# loaded, fine-tuned and saved again in the fine-tuning section below.
proj_id = 'VideoModel'
gen_name = proj_id + '_gen'
pre_gen_name = gen_name + '_0'
crit_name = proj_id + '_crit'

# Folder (under the dataset root) that receives generator outputs; its name
# is also used as one of the two critic classes in get_crit_data calls.
name_gen = proj_id + '_image_gen'
path_gen = path/name_gen

# Passed as base_dir to every Tensorboard writer callback in this notebook.
TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)

# nf_factor is forwarded to gen_learner_wide; xtra_tfms to get_colorize_data;
# pct_start to fit_one_cycle.
nf_factor = 2
# NOTE(review): presumably noise augmentation applied with probability 0.8 -- confirm in deoldify.augs.
xtra_tfms=[noisify(p=0.8)]
pct_start = 1e-8
python
def get_data(bs:int, sz:int, keep_pct:float):
    """Build the generator training data via ``get_colorize_data``.

    Inputs come from the module-level ``path_lr`` (black-and-white) and
    targets from ``path_hr`` (colour); the module-level ``xtra_tfms`` are
    passed through as extra transforms and no random seed is fixed.

    Args:
        bs: batch size.
        sz: image size.
        keep_pct: forwarded unchanged to ``get_colorize_data``.

    Returns:
        Whatever ``get_colorize_data`` returns (used as ``data`` for the
        generator learner elsewhere in this notebook).
    """
    colorize_kwargs = dict(
        sz=sz,
        bs=bs,
        crappy_path=path_lr,
        good_path=path_hr,
        random_seed=None,
        keep_pct=keep_pct,
        xtra_tfms=xtra_tfms,
    )
    return get_colorize_data(**colorize_kwargs)

def get_crit_data(classes, bs, sz):
    """Build the critic's classification data from folders under ``path``.

    Args:
        classes: folder names to include; they also serve as the labels.
        bs: batch size.
        sz: image size applied by the transform step.

    Returns:
        The normalized (``imagenet_stats``) databunch.
    """
    images = ImageList.from_folder(path, include=classes, recurse=True)
    # Fixed seed keeps the 10% random split reproducible between calls.
    split = images.split_by_rand_pct(0.1, seed=42)
    labelled = split.label_from_folder(classes=classes)
    transformed = labelled.transform(get_transforms(max_zoom=2.), size=sz)
    return transformed.databunch(bs=bs).normalize(imagenet_stats)

def create_training_images(fn,i):
    """Write a black-and-white copy of image ``fn`` under ``path_lr``.

    The copy keeps the same path relative to ``path_hr`` that the source
    has, and parent folders are created as needed. The image is converted
    to 'LA' and then back to 'RGB', so the saved file is still three-channel.

    Args:
        fn: path of the source image; must be located under ``path_hr``
            (``relative_to`` raises ``ValueError`` otherwise).
        i: index passed in by the ``parallel`` helper; not used here.
    """
    dest = path_lr/fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # The context manager releases the source file handle even when the
    # conversion or the save fails, which matters when this runs over a
    # very large number of files.
    with PIL.Image.open(fn) as src:
        src.convert('LA').convert('RGB').save(dest)
    
def save_preds(dl):
    """Run the global ``learn_gen`` over ``dl`` and save every prediction.

    Each prediction is written into ``path_gen`` under the file name of the
    dataset item at the same running position, so this assumes ``dl``
    yields items in the same order as ``dl.dataset.items``.

    Args:
        dl: data loader to predict on.
    """
    items = dl.dataset.items
    written = 0  # number of predictions saved so far, across batches

    for batch in dl:
        batch_preds = learn_gen.pred_batch(batch=batch, reconstruct=True)
        for pos, pred in enumerate(batch_preds, start=written):
            pred.save(path_gen/items[pos].name)
        written += len(batch_preds)
            
def save_gen_images():
    """Regenerate the folder of generator outputs at ``path_gen``.

    Any existing ``path_gen`` folder is deleted first. Data is built with
    the notebook-global ``bs`` and ``sz`` and a fixed ``keep_pct`` of 0.085,
    and predictions are saved through ``save_preds``.
    """
    # Start from an empty output folder so stale images never mix in.
    if path_gen.exists():
        shutil.rmtree(path_gen)
    path_gen.mkdir(exist_ok=True)

    gen_data = get_data(bs=bs, sz=sz, keep_pct=0.085)
    save_preds(gen_data.fix_dl)

    # Opens the first saved file; the resulting image object is neither
    # returned nor used further.
    first_saved = path_gen.ls()[0]
    PIL.Image.open(first_saved)

Create black and white training images

This only runs if the black-and-white directory does not already exist.

python
# Only the existence of path_lr is checked: if the folder is already there
# (even partially populated), no images are generated.
if not path_lr.exists():
    il = ImageList.from_folder(path_hr)
    # create_training_images is applied to every image found under path_hr.
    parallel(create_training_images, il.items)

Finetune Generator With Noise Augmented Images.

This helps the generator cope better with noisy/grainy video, which is very common.
python
# Batch size, image size and keep_pct handed to get_data for the
# generator fine-tuning run.
bs=8
sz=192
keep_pct=0.25
python
# data_gen is also reused later as the data for the generator learners
# rebuilt in the GAN cycle.
data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)
python
# Generator learner trained with FeatureLoss; nf_factor comes from Setup.
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)
python
# Register a Tensorboard writer under TENSORBOARD_PATH with run name 'GenPre'.
learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))
python
# Load the pretrained generator checkpoint; with_opt=False skips the saved
# optimizer state.
learn_gen = learn_gen.load(pre_gen_name, with_opt=False)
python
# Make all layer groups trainable for the fine-tune.
learn_gen.unfreeze()
python
# One epoch with learning rates sliced from 5e-8 to 5e-5;
# pct_start is the (very small) value defined in Setup.
learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))
python
# Saves under the same name that was loaded above, so the original
# pretrained checkpoint is overwritten by the noise-fine-tuned weights.
learn_gen.save(pre_gen_name)

Repeatable GAN Cycle

NOTE

Best results so far have come from doing only a single run of the cells below (otherwise, glitches are introduced that are visible in video).

python
# "old" names are the checkpoints loaded in this cycle, "new" names the ones
# written by it. With old_checkpoint_num = 0, gen_old_checkpoint_name equals
# pre_gen_name, i.e. the generator saved by the fine-tuning section.
old_checkpoint_num = 0
checkpoint_num = old_checkpoint_num + 1
gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)
gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)
crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)
crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)

Save Generated Images

python
# Read as globals by save_gen_images when it builds its data.
bs=8
sz=192
python
# Fresh generator learner restored from the "old" checkpoint; save_preds
# uses this global learn_gen to produce the images.
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)
python
# Deletes any existing path_gen folder and refills it with generator outputs.
save_gen_images()

Pretrain Critic

python
# Batch size and image size for critic pretraining.
bs=16
sz=192
python
# Drop the reference to the generator learner and force a collection
# before the critic is built.
learn_gen=None
gc.collect()
python
# Two classes: the generated-image folder (name_gen) and the 'test' folder.
# NOTE(review): 'test' presumably holds real colour images -- confirm against the dataset layout.
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)
python
# Visual sanity check of a training batch.
data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)
python
# Loads an existing critic checkpoint named crit_old_checkpoint_name; no cell
# shown in this notebook creates that checkpoint, so it must already exist.
learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)
python
# Register a Tensorboard writer under TENSORBOARD_PATH with run name 'CriticPre'.
learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))
python
# Four epochs of one-cycle training at a learning rate of 1e-4.
learn_critic.fit_one_cycle(4, 1e-4)
python
# Stored under the "new" name; the GAN section loads the critic from it.
learn_critic.save(crit_new_checkpoint_name)

GAN

python
# Release every learner from the previous sections before rebuilding them
# for GAN training. The critic-pretraining learner is bound to
# `learn_critic` (not `learn_crit`), so it has to be cleared by that name
# as well; otherwise it stays referenced for the whole GAN run.
learn_critic=None
learn_crit=None
learn_gen=None
gc.collect()
python
# Learning rate, image size and batch size for GAN training.
lr=5e-6
sz=192
bs=5
python
# Critic data rebuilt with the GAN batch size set just above.
data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)
python
# Critic restored from the checkpoint written by the pretraining section.
learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)
python
# Generator restored from the "old" checkpoint (the one the generated images
# were produced with); it still uses data_gen here, and the GAN learner's
# data is replaced before fitting.
learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)
python
# Combine generator and critic into one GAN learner. Switching between the
# two is delegated to AdaptiveGANSwitcher with critic_thresh=0.65.
switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)
learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,
                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)
# NOTE(review): mult_lr=5. presumably scales the critic's learning rate relative to the generator's -- confirm.
learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))
learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100, stats_iters=10, loss_iters=1))
# Generator checkpoints go to gen_new_checkpoint_name (save_iters=100); these
# are the checkpoints to inspect when looking for the point where glitches begin.
learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))

Instructions:

Find the checkpoint just before the point where glitches start to be introduced. So far, this has been found after iterating through 1.4% of the data when using a learning rate of 1e-5, and after 2.2% of the data when using 5e-6.

python
# Train the GAN on a small subset (keep_pct=0.03), replacing the data the
# learners were built with.
learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)
# Leave only the generator's last layer group trainable.
learn_gen.freeze_to(-1)
# A single epoch at the learning rate set above.
learn.fit(1,lr)