code/artificial_intelligence/src/yolo_v1/yolo-v1-implementation.ipynb
import torch.nn as nn
import torch
#for dataset loading
import os
import pandas as pd
from PIL import Image
#for training
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.transforms.functional as FT
from tqdm import tqdm #for progress bar
from torch.utils.data import DataLoader
#for visualization
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from collections import Counter
#For checkpoint
import time
import sys
import datetime
# Fix the RNG seed so that weight initialisation and shuffling are repeatable.
seed = 123
torch.manual_seed(seed)
# Hyperparameters
LEARNING_RATE = 2e-5
# torch.cuda.is_available is a function and has to be *called*: the bare
# function object is always truthy, which selected "cuda" on CPU-only hosts.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16  # 64 in the original paper; lowered to 16 because of limited compute
WEIGHT_DECAY = 0
EPOCHS = 1000
NUM_WORKERS = 2
PIN_MEMORY = True
# True to start from the trained model stored at LOAD_MODEL_FILE, False to start fresh.
LOAD_MODEL = False
LOAD_MODEL_FILE = "/kaggle/working/best_model.pth.tar"
# True to train the model, False to only run the testing part.
IS_TRAIN = True
# Dataset paths.
#
# Change the root directory according to your file layout.
# For Kaggle keep the default.
# For Google Colab, mount the desired Google Drive and set the path to
# 'gdrive/My Drive/<your-dataset-file>'.
# For a local PC, set the value to wherever you stored the dataset after downloading it from
# https://www.kaggle.com/dataset/734b7bcb7ef13a045cbdd007a3c19874c2586ed0b02b4afc86126e89d00af8d2.
ROOT_DIR = '/kaggle/input/pascalvoc-yolo/'
# os.path.join only inserts a separator when one is missing, so ROOT_DIR may be
# written with or without a trailing slash (plain "+" silently produced
# ".../<your-dataset-file>images" for the Colab-style path above).
IMG_DIR = os.path.join(ROOT_DIR, "images")
LABEL_DIR = os.path.join(ROOT_DIR, "labels")
FULL_DATASET_FILEPATH = os.path.join(ROOT_DIR, "train.csv")
EIGHT_EXAMPLE_DATASET_FILEPATH = os.path.join(ROOT_DIR, "8examples.csv")
HUNDRED_EXAMPLE_DATASET_FILEPATH = os.path.join(ROOT_DIR, "100examples.csv")
TEST_DATASET_FILEPATH = os.path.join(ROOT_DIR, "test.csv")
# Class names; the position in the list is the numeric class id
# (plot_image looks names up with CLASS_LABELS[int(class_id)]).
CLASS_LABELS = [
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "table",
    "dog",
    "horse",
    "motorbike",
    "person",
    "potted plant",
    "sheep",
    "sofa",
    "train",
    "tv/monitor",
]
assert len(CLASS_LABELS) == 20, "Some classes are missing from class labels!"
# Bare string expression: it has no effect when this runs as a script
# (presumably it was shown as the cell value in the notebook).
"Class labels are correct!"
# Convolutional part of the network (see the YOLO v1 paper); the fully
# connected head is not described here.
#
# Entry kinds:
#   tuple -> one conv layer: (kernel_size, num_filters, stride, padding)
#   "M"   -> one max-pool layer
#   list  -> [conv tuple, ..., number of times the conv tuples are repeated]
architecture_config = [
    (7, 64, 2, 3),
    "M",
    (3, 192, 1, 1),
    "M",
    (1, 128, 1, 0),
    (3, 256, 1, 1),
    (1, 256, 1, 0),
    (3, 512, 1, 1),
    "M",
    [(1, 256, 1, 0), (3, 512, 1, 1), 4],
    (1, 512, 1, 0),
    (3, 1024, 1, 1),
    "M",
    [(1, 512, 1, 0), (3, 1024, 1, 1), 2],
    (3, 1024, 1, 1),
    (3, 1024, 2, 1),
    (3, 1024, 1, 1),
    (3, 1024, 1, 1),
]
# Building block of the conv stack: convolution -> batch norm -> leaky ReLU.
class CNNBlock(nn.Module):
    """Conv2d followed by BatchNorm2d and LeakyReLU(0.1).

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of filters, also the size of the batch norm.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (the model builder
            passes ``kernel_size``, ``stride`` and ``padding``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # The convolution is created without a bias term.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        # Batch norm was not part of YOLO v1, but it gives better results here.
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.leakyrelu = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Apply convolution, normalisation and activation, in that order."""
        features = self.conv(x)
        normalized = self.batchnorm(features)
        return self.leakyrelu(normalized)
class Yolov1(nn.Module):
    """YOLO v1 network: the conv stack ("darknet") plus a fully connected head.

    Args:
        in_channels: number of channels of the input images.
        **kwargs: ``split_size``, ``num_boxes`` and ``num_classes``; they are
            forwarded to ``_create_fcs`` and size the output layer.

    ``forward`` returns one flat vector of ``S * S * (C + B * 5)`` values per
    input image.
    """

    def __init__(self, in_channels=3, **kwargs):
        super().__init__()
        self.architecture = architecture_config
        self.in_channels = in_channels
        # The whole convolutional section is called darknet;
        # darknet + fully connected layers = YOLO.
        self.darknet = self._create_conv_layers(self.architecture)
        self.fcs = self._create_fcs(**kwargs)

    def forward(self, x):
        x = self.darknet(x)
        # Flatten from dim 1 on: dim 0 holds the examples of the batch.
        return self.fcs(torch.flatten(x, start_dim=1))

    def _create_conv_layers(self, architecture):
        """Translate an architecture description into an ``nn.Sequential``.

        Entry kinds: a tuple ``(kernel_size, num_filters, stride, padding)``
        is one conv block, a string is one 2x2 max-pool, and a list is
        ``[conv tuple, ..., repeats]``.

        Raises:
            TypeError: for an entry of any other type (such entries used to be
                skipped silently, hiding typos in the configuration).
        """
        layers = []
        in_channels = self.in_channels

        for entry in architecture:
            if isinstance(entry, tuple):
                kernel_size, num_filters, stride, padding = entry
                layers.append(
                    CNNBlock(
                        in_channels,
                        num_filters,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding,
                    )
                )
                in_channels = num_filters
            elif isinstance(entry, str):
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            elif isinstance(entry, list):
                # The repeat count is the LAST element. Reading it from index 2
                # only worked for groups of exactly two conv tuples.
                *conv_specs, num_repeats = entry
                for _ in range(num_repeats):
                    for kernel_size, num_filters, stride, padding in conv_specs:
                        layers.append(
                            CNNBlock(
                                in_channels,
                                num_filters,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding,
                            )
                        )
                        in_channels = num_filters
            else:
                raise TypeError(
                    f"Unsupported architecture entry {entry!r}: "
                    "expected a tuple, a str or a list"
                )

        return nn.Sequential(*layers)

    def _create_fcs(self, split_size, num_boxes, num_classes):
        """Build the fully connected head for an S x S grid, B boxes, C classes."""
        S, B, C = split_size, num_boxes, num_classes
        return nn.Sequential(
            # Kept although forward() already flattens: removing it would
            # renumber the layers and change the state_dict keys.
            nn.Flatten(),
            nn.Linear(1024 * S * S, 496),  # 4096 in the original paper instead of 496
            nn.Dropout(0.0),  # 0.5 in the original paper
            nn.LeakyReLU(0.1),
            nn.Linear(496, S * S * (C + B * 5)),  # reshaped later to (S, S, C + B * 5); 30 for C=20, B=2
        )
# Sanity check: use this to confirm that the YOLO model is wired correctly.
def test(S=7, B=2, C=20):
    """Build a model and check the shape of its output for a batch of two images.

    The conv stack halves the resolution six times (two stride-2 convolutions
    and four max-pools in architecture_config), so an input of 64 * S pixels
    per side yields the S x S grid the head expects (448 pixels for S = 7).

    Raises:
        AssertionError: if the output is not ``(2, S * S * (C + B * 5))``.
    """
    model = Yolov1(split_size=S, num_boxes=B, num_classes=C)
    x = torch.randn((2, 3, 64 * S, 64 * S))
    # Derived from the arguments; the former hard-coded 1470 was only right
    # for the defaults.
    expected_shape = (2, S * S * (C + B * 5))
    with torch.no_grad():  # only the shape matters, no graph is needed
        output_shape = tuple(model(x).shape)
    assert output_shape == expected_shape, "Something wrong with the model!"
    print('Model is correct!')


test()
# Splits boxes into their four corner coordinates, keeping the last dimension.
def _box_corners(boxes, box_format):
    if box_format == "midpoint":
        half_w = boxes[..., 2:3] / 2
        half_h = boxes[..., 3:4] / 2
        return (
            boxes[..., 0:1] - half_w,
            boxes[..., 1:2] - half_h,
            boxes[..., 0:1] + half_w,
            boxes[..., 1:2] + half_h,
        )
    if box_format == "corners":
        return boxes[..., 0:1], boxes[..., 1:2], boxes[..., 2:3], boxes[..., 3:4]
    raise ValueError(
        f"Unknown box_format {box_format!r}: expected 'midpoint' or 'corners'"
    )


# Calculates intersection over union (IoU)
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """Return the IoU of predicted and target boxes.

    Args:
        boxes_preds: tensor whose last dimension holds the 4 box values.
        boxes_labels: tensor of target boxes in the same layout.
        box_format: "midpoint" for (x, y, w, h) or "corners" for
            (x1, y1, x2, y2).

    Returns:
        Tensor of IoU values; the last dimension has size 1.

    Raises:
        ValueError: if ``box_format`` is neither "midpoint" nor "corners"
            (this used to surface as an unrelated unbound-variable error).
    """
    box1_x1, box1_y1, box1_x2, box1_y2 = _box_corners(boxes_preds, box_format)
    box2_x1, box2_y1, box2_x2, box2_y2 = _box_corners(boxes_labels, box_format)

    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # .clamp(0) covers the case where the boxes do not intersect.
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    # 1e-6 keeps the division defined when both boxes have zero area.
    return intersection / (box1_area + box2_area - intersection + 1e-6)
# Non-max suppression over the boxes of one image
def non_max_suppression(bboxes, iou_threshold, threshold, box_format="corners"):
    """Filter boxes by confidence, then drop same-class overlapping boxes.

    Args:
        bboxes: list of boxes, each ``[class, score, c1, c2, c3, c4]``.
        iou_threshold: a box is dropped when its IoU with an already chosen
            box of the same class is not below this value.
        threshold: boxes whose score is not above this value are discarded
            first.
        box_format: layout of the 4 coordinates, passed to
            intersection_over_union.

    Returns:
        The kept boxes, highest score first.
    """
    assert type(bboxes) is list

    candidates = sorted(
        (candidate for candidate in bboxes if candidate[1] > threshold),
        key=lambda candidate: candidate[1],
        reverse=True,
    )
    kept = []

    while candidates:
        best, *remaining = candidates
        candidates = []
        for other in remaining:
            # Boxes of a different class never compete with each other.
            if other[0] != best[0]:
                candidates.append(other)
                continue
            overlap = intersection_over_union(
                torch.tensor(best[2:]),
                torch.tensor(other[2:]),
                box_format=box_format,
            )
            if overlap < iou_threshold:
                candidates.append(other)
        kept.append(best)

    return kept
# Mean average precision over all classes
def mean_average_precision(
    pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
):
    """Compute the mean average precision (mAP) at one IoU threshold.

    Args:
        pred_boxes: list of predictions, each
            ``[image_idx, class, score, c1, c2, c3, c4]``.
        true_boxes: list of ground-truth boxes in the same layout.
        iou_threshold: a prediction counts as correct when its best IoU with
            a ground truth of the same image and class exceeds this value.
        box_format: layout of the 4 coordinates, passed to
            intersection_over_union.
        num_classes: class ids ``0 .. num_classes - 1`` are evaluated.

    Returns:
        A scalar tensor: the mean of the per-class average precisions, taken
        over the classes that have at least one ground-truth box, or 0.0 when
        no class has any (this case used to raise ZeroDivisionError).
    """
    # Average precision of every class that has ground truths.
    average_precisions = []

    # Used for numerical stability in the divisions below.
    epsilon = 1e-6

    for c in range(num_classes):
        # Keep only the predictions and targets of the current class.
        detections = [detection for detection in pred_boxes if detection[1] == c]
        ground_truths = [true_box for true_box in true_boxes if true_box[1] == c]

        total_true_bboxes = len(ground_truths)
        # Without ground truths there is nothing to score for this class.
        if total_true_bboxes == 0:
            continue

        # Group the ground truths by image once, so that each detection only
        # looks at the boxes of its own image instead of rescanning all.
        gts_by_image = {}
        for gt in ground_truths:
            gts_by_image.setdefault(gt[0], []).append(gt)
        # One "already matched" flag per ground truth, e.g. for 3 boxes in
        # image 0 and 2 in image 1: {0: [False] * 3, 1: [False] * 2}.
        matched = {
            image_idx: [False] * len(image_gts)
            for image_idx, image_gts in gts_by_image.items()
        }

        # Most confident detections first; the score is at index 2.
        detections.sort(key=lambda detection: detection[2], reverse=True)
        TP = torch.zeros(len(detections))
        FP = torch.zeros(len(detections))

        for detection_idx, detection in enumerate(detections):
            image_gts = gts_by_image.get(detection[0], [])

            best_iou = 0
            best_gt_idx = None
            for idx, gt in enumerate(image_gts):
                iou = intersection_over_union(
                    torch.tensor(detection[3:]),
                    torch.tensor(gt[3:]),
                    box_format=box_format,
                )
                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            is_hit = best_gt_idx is not None and best_iou > iou_threshold
            # A ground truth may only be matched once; later detections that
            # hit the same box, and detections with too small an IoU, are
            # false positives.
            if is_hit and not matched[detection[0]][best_gt_idx]:
                TP[detection_idx] = 1
                matched[detection[0]][best_gt_idx] = True
            else:
                FP[detection_idx] = 1

        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = TP_cumsum / (TP_cumsum + FP_cumsum + epsilon)
        # Start the precision/recall curve at the point (recall 0, precision 1).
        precisions = torch.cat((torch.tensor([1.0]), precisions))
        recalls = torch.cat((torch.tensor([0.0]), recalls))
        # torch.trapz integrates the curve numerically.
        average_precisions.append(torch.trapz(precisions, recalls))

    if not average_precisions:
        return torch.tensor(0.0)
    return sum(average_precisions) / len(average_precisions)
# Draws an image together with its bounding boxes
def plot_image(image, boxes):
    """Show ``image`` with one labelled red rectangle per box.

    Args:
        image: anything ``np.array`` turns into a (height, width, channels)
            array.
        boxes: iterable of ``[class, score, x_mid, y_mid, w, h]``; the box
            values are multiplied by the image width/height, i.e. they are
            treated as fractions of the image size.
    """
    im = np.array(image)
    height, width, _ = im.shape

    # One figure with a single axes that shows the image.
    fig, ax = plt.subplots(1)
    ax.imshow(im)

    for box in boxes:
        class_label = CLASS_LABELS[int(box[0])]
        score = box[1]
        coords = box[2:]
        # Exactly the 4 box dimensions must remain after class and score.
        assert len(coords) == 4, "Got more values than in x, y, w, h, in a box!"
        x_mid, y_mid, box_w, box_h = coords

        # Rectangle wants the upper-left corner, the boxes store the midpoint.
        left = (x_mid - box_w / 2) * width
        top = (y_mid - box_h / 2) * height
        rect = patches.Rectangle(
            (left, top),
            box_w * width,
            box_h * height,
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        label = "%s (%.3f)" % (class_label, score)
        plt.text(
            left,
            top,
            label,
            color='white',
            bbox=dict(facecolor='red', alpha=0.5, edgecolor='red'),
        )
        # Add the rectangle to the axes.
        ax.add_patch(rect)

    plt.tight_layout()
    plt.show()
def get_bboxes(
    loader,
    model,
    iou_threshold,
    threshold,
    pred_format="cells",
    box_format="midpoint",
    device="cuda",
):
    """Collect predicted and ground-truth boxes for every image of ``loader``.

    The model is put in eval mode for the pass and afterwards returned to the
    mode it was in before the call (previously it was always left in train
    mode, even when the caller had it in eval mode).

    Args:
        loader: iterable yielding ``(images, labels)`` batches.
        model: the network; called on each image batch.
        iou_threshold: IoU threshold for non-max suppression.
        threshold: confidence threshold, used for the suppression and for
            selecting the ground-truth cells that contain an object.
        pred_format: accepted for compatibility; not used.
        box_format: coordinate layout passed to non_max_suppression.
        device: device the batches are moved to.

    Returns:
        ``(all_pred_boxes, all_true_boxes)``; every entry is
        ``[image_idx, class, score, x, y, w, h]`` where image_idx counts the
        images in the order they were yielded.
    """
    all_pred_boxes = []
    all_true_boxes = []

    was_training = model.training
    # The model has to be in eval mode while boxes are collected.
    model.eval()
    train_idx = 0

    try:
        for x, labels in loader:
            x = x.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                predictions = model(x)

            batch_size = x.shape[0]
            true_bboxes = cellboxes_to_boxes(labels)
            bboxes = cellboxes_to_boxes(predictions)

            for idx in range(batch_size):
                nms_boxes = non_max_suppression(
                    bboxes[idx],
                    iou_threshold=iou_threshold,
                    threshold=threshold,
                    box_format=box_format,
                )

                for nms_box in nms_boxes:
                    all_pred_boxes.append([train_idx] + nms_box)

                for box in true_bboxes[idx]:
                    # Cells without an object have score 0 and are skipped.
                    if box[1] > threshold:
                        all_true_boxes.append([train_idx] + box)

                train_idx += 1
    finally:
        # Restore the caller's mode, also when the pass fails half-way.
        model.train(was_training)

    return all_pred_boxes, all_true_boxes
# Converts YOLO output from cell-relative to image-relative boxes
def convert_cellboxes(predictions, S=7, C=20):
    """Turn raw cell predictions into one image-relative box per grid cell.

    Args:
        predictions: tensor with ``batch * S * S * (C + 10)`` values, either
            flat per image or already shaped ``(batch, S, S, C + 10)``.
            Per cell: C class scores, then two boxes, each as
            (confidence, x, y, w, h).
        S: grid size. It is now used for the reshape and the cell offsets as
            well; these were fixed to 7 regardless of this argument.
        C: number of class scores per cell (20 as before by default).

    Returns:
        Tensor ``(batch, S, S, 6)`` holding, per cell,
        (predicted class, best confidence, x, y, w, h), with the box taken
        from whichever of the two boxes has the higher confidence and scaled
        by ``1 / S``.
    """
    predictions = predictions.to("cpu")
    batch_size = predictions.shape[0]
    predictions = predictions.reshape(batch_size, S, S, C + 10)

    bboxes1 = predictions[..., C + 1:C + 5]
    bboxes2 = predictions[..., C + 6:C + 10]
    scores = torch.cat(
        (predictions[..., C].unsqueeze(0), predictions[..., C + 5].unsqueeze(0)),
        dim=0,
    )
    # 0 where the first box is more confident, 1 where the second one is.
    best_box = scores.argmax(0).unsqueeze(-1)
    best_boxes = bboxes1 * (1 - best_box) + best_box * bboxes2

    # cell_indices[b, i, j, 0] == j, the column index of the cell; the
    # permuted version gives the row index i.
    cell_indices = torch.arange(S).repeat(batch_size, S, 1).unsqueeze(-1)
    x = 1 / S * (best_boxes[..., :1] + cell_indices)
    y = 1 / S * (best_boxes[..., 1:2] + cell_indices.permute(0, 2, 1, 3))
    w_h = 1 / S * best_boxes[..., 2:4]
    converted_bboxes = torch.cat((x, y, w_h), dim=-1)

    # Cast explicitly so the concatenation does not rely on type promotion.
    predicted_class = (
        predictions[..., :C].argmax(-1).unsqueeze(-1).to(converted_bboxes.dtype)
    )
    best_confidence = torch.max(
        predictions[..., C], predictions[..., C + 5]
    ).unsqueeze(-1)

    return torch.cat((predicted_class, best_confidence, converted_bboxes), dim=-1)
# Converts cell boxes to plain lists of boxes for the whole image
def cellboxes_to_boxes(out, S=7):
    """Return, per image, the list of its ``S * S`` cell boxes as Python lists.

    Args:
        out: model output or label tensor accepted by convert_cellboxes.
        S: grid size; it is forwarded to convert_cellboxes so that both
            functions always agree on it.

    Returns:
        ``all_bboxes[image][cell]`` is a list of numbers
        ``[class, confidence, x, y, w, h]``.
    """
    converted = convert_cellboxes(out, S=S)
    converted_pred = converted.reshape(out.shape[0], S * S, converted.shape[-1])
    # Truncate the class column to whole numbers (it stays in the tensor's dtype).
    converted_pred[..., 0] = converted_pred[..., 0].long()
    # One tolist() call yields the same nested lists as calling .item() on
    # every element, without the per-element overhead.
    return converted_pred.tolist()
# Writes a training checkpoint to disk
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce the save on stdout, then store ``state`` at ``filename``.

    Args:
        state: object handed to ``torch.save`` (load_checkpoint expects a dict
            with the keys "state_dict" and "optimizer").
        filename: destination path.
    """
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)
# Restores model and optimizer from a checkpoint written by save_checkpoint
def load_checkpoint(checkpoint, model, optimizer):
    """Load the stored states into ``model`` and ``optimizer`` in place.

    Args:
        checkpoint: mapping with the keys "state_dict" (model state) and
            "optimizer" (optimizer state).
        model: module that receives the model state.
        optimizer: optimizer that receives the optimizer state.
    """
    print("=> Loading checkpoint")
    model_state = checkpoint["state_dict"]
    model.load_state_dict(model_state)
    optimizer_state = checkpoint["optimizer"]
    optimizer.load_state_dict(optimizer_state)
class YoloLoss(nn.Module):
    """Sum-of-squares YOLO v1 loss.

    Args:
        S: split size of the image (7 in the paper).
        B: number of boxes per cell (2 in the paper). ``forward`` compares
            exactly two predicted boxes, so only ``B = 2`` is supported.
        C: number of classes (20 in the paper and in the VOC dataset).

    Channel layout per cell: ``0 .. C-1`` class scores, ``C`` confidence of
    box 1, ``C+1 .. C+4`` box 1 (x, y, w, h), ``C+5`` confidence of box 2,
    ``C+6 .. C+9`` box 2. Targets only fill the first box. All offsets are
    derived from ``C``; they used to be hard-coded for 20 classes.
    """

    def __init__(self, S=7, B=2, C=20):
        super().__init__()
        self.mse = nn.MSELoss(reduction="sum")
        self.S = S
        self.B = B
        self.C = C

        # Weights from the YOLO paper: how much cells without an object
        # (noobj) and the box coordinates (coord) contribute to the loss.
        self.lambda_noobj = 0.5
        self.lambda_coord = 5

    def forward(self, predictions, target):
        """Return the scalar loss for a batch.

        Args:
            predictions: model output with ``S * S * (C + B * 5)`` values per
                image; it is reshaped to ``(-1, S, S, C + B * 5)``.
            target: label tensor indexed as ``(N, S, S, channels)`` with the
                channel layout described on the class.
        """
        C = self.C
        predictions = predictions.reshape(-1, self.S, self.S, C + self.B * 5)

        pred_conf1 = predictions[..., C:C + 1]
        pred_box1 = predictions[..., C + 1:C + 5]
        pred_conf2 = predictions[..., C + 5:C + 6]
        pred_box2 = predictions[..., C + 6:C + 10]
        target_conf = target[..., C:C + 1]
        target_box = target[..., C + 1:C + 5]

        # IoU of each of the two predicted boxes with the target box.
        iou_b1 = intersection_over_union(pred_box1, target_box)
        iou_b2 = intersection_over_union(pred_box2, target_box)
        ious = torch.cat([iou_b1.unsqueeze(0), iou_b2.unsqueeze(0)], dim=0)

        # bestbox is 0 where the first box has the higher IoU, 1 otherwise.
        _, bestbox = torch.max(ious, dim=0)
        exists_box = target_conf  # Iobj_i in the paper, shape (N, S, S, 1)

        # ======================== #
        #   FOR BOX COORDINATES    #
        # ======================== #

        # Zero out cells without an object and keep, per cell, only the
        # predicted box with the highest IoU.
        box_predictions = exists_box * (
            bestbox * pred_box2 + (1 - bestbox) * pred_box1
        )
        box_targets = exists_box * target_box

        # Compare square roots of width and height. Predictions may be
        # negative, so take the root of the magnitude and restore the sign;
        # 1e-6 is added before the root is taken.
        box_predictions[..., 2:4] = torch.sign(box_predictions[..., 2:4]) * torch.sqrt(
            torch.abs(box_predictions[..., 2:4] + 1e-6)
        )
        box_targets[..., 2:4] = torch.sqrt(box_targets[..., 2:4])

        box_loss = self.mse(
            torch.flatten(box_predictions, end_dim=-2),
            torch.flatten(box_targets, end_dim=-2),
        )

        # ==================== #
        #   FOR OBJECT LOSS    #
        # ==================== #

        # Confidence of the predicted box with the highest IoU.
        pred_box = bestbox * pred_conf2 + (1 - bestbox) * pred_conf1

        object_loss = self.mse(
            torch.flatten(exists_box * pred_box),
            torch.flatten(exists_box * target_conf),
        )

        # ======================= #
        #   FOR NO OBJECT LOSS    #
        # ======================= #

        # (N, S, S, 1) -> (N, S*S); both predicted confidences are penalised.
        no_object_loss = self.mse(
            torch.flatten((1 - exists_box) * pred_conf1, start_dim=1),
            torch.flatten((1 - exists_box) * target_conf, start_dim=1),
        )
        no_object_loss += self.mse(
            torch.flatten((1 - exists_box) * pred_conf2, start_dim=1),
            torch.flatten((1 - exists_box) * target_conf, start_dim=1),
        )

        # ================== #
        #   FOR CLASS LOSS   #
        # ================== #

        # (N, S, S, C) -> (N*S*S, C)
        class_loss = self.mse(
            torch.flatten(exists_box * predictions[..., :C], end_dim=-2),
            torch.flatten(exists_box * target[..., :C], end_dim=-2),
        )

        loss = (
            self.lambda_coord * box_loss  # first two rows in the paper
            + object_loss  # third row in the paper
            + self.lambda_noobj * no_object_loss  # fourth row
            + class_loss  # fifth row
        )

        return loss
# Common for Kaggle, Colab and local runs: pass the directories and the csv accordingly
class VOCDataset(torch.utils.data.Dataset):
    """Dataset of images with YOLO-style label files, encoded per grid cell.

    Args:
        csv_file: csv read with pandas; column 0 holds the image file name,
            column 1 the label file name.
        img_dir: directory containing the images.
        label_dir: directory containing the label files; every line is
            ``class x y width height`` separated by whitespace.
        S, B, C: grid size, boxes per cell and number of classes.
        transform: optional callable ``(image, boxes) -> (image, boxes)``.
    """

    def __init__(
        self, csv_file, img_dir, label_dir, S=7, B=2, C=20, transform=None,
    ):
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transform = transform
        self.S = S
        self.B = B
        self.C = C

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        """Return ``(image, label_matrix)`` for the example at ``index``."""
        label_path = os.path.join(self.label_dir, self.annotations.iloc[index, 1])
        boxes = torch.tensor(self._read_boxes(label_path))

        img_path = os.path.join(self.img_dir, self.annotations.iloc[index, 0])
        image = Image.open(img_path)

        if self.transform:
            image, boxes = self.transform(image, boxes)

        return image, self._encode_boxes(boxes)

    @staticmethod
    def _read_boxes(label_path):
        """Parse a label file into ``[class, x, y, width, height]`` float lists.

        Every field is parsed with ``float``: the former ``int(text)`` branch
        raised ValueError for integral values written with a decimal point
        (e.g. a width of "1.0"). Blank lines are skipped.
        """
        boxes = []
        with open(label_path) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                class_label, x, y, width, height = (float(v) for v in fields)
                boxes.append([class_label, x, y, width, height])
        return boxes

    def _encode_boxes(self, boxes):
        """Encode boxes as an ``(S, S, C + 5 * B)`` label matrix.

        Per cell: a one-hot class in ``0 .. C-1``, an object flag at ``C`` and
        the box (x, y relative to the cell; width, height in cell units) at
        ``C+1 .. C+4``. Only the first box that falls into a cell is stored.
        """
        S, C = self.S, self.C
        label_matrix = torch.zeros((S, S, C + 5 * self.B))

        for box in boxes:
            class_label, x, y, width, height = box.tolist()
            class_label = int(class_label)

            # i, j are the cell row and cell column. A coordinate of exactly
            # 1.0 would land on index S, so clamp it into the last cell.
            i = min(int(S * y), S - 1)
            j = min(int(S * x), S - 1)
            x_cell, y_cell = S * x - j, S * y - i

            # Width and height of the box measured in cells.
            width_cell, height_cell = width * S, height * S

            # Only if no object was stored for cell (i, j) yet.
            if label_matrix[i, j, C] == 0:
                # Mark that an object exists in this cell.
                label_matrix[i, j, C] = 1
                label_matrix[i, j, C + 1:C + 5] = torch.tensor(
                    [x_cell, y_cell, width_cell, height_cell]
                )
                label_matrix[i, j, class_label] = 1

        return label_matrix
# User-defined pipeline of image transforms
class Compose(object):
    """Chain image transforms while passing the bounding boxes through.

    Args:
        transforms: sequence of callables, each taking and returning an image.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, bboxes):
        """Apply every transform to ``img`` in order; ``bboxes`` is returned untouched."""
        for apply_step in self.transforms:
            img = apply_step(img)
        return img, bboxes
# Preprocessing applied to every image: resize to 448x448, then convert to a
# tensor. Normalization can also be added here.
transform = Compose(
    [
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
    ]
)
def train_fn(train_loader, model, optimizer, loss_fn):
    """Run one epoch of training and print the mean batch loss.

    Args:
        train_loader: iterable yielding ``(images, labels)`` batches.
        model: network to train; it is switched to train mode first.
        optimizer: optimizer stepping the model parameters.
        loss_fn: callable ``(predictions, labels) -> scalar loss tensor``.
    """
    # Make sure layers such as batch norm run in training mode, whatever
    # mode an earlier evaluation left the model in.
    model.train()
    loop = tqdm(train_loader, leave=True)  # progress bar
    batch_losses = []

    for x, y in loop:
        x, y = x.to(DEVICE), y.to(DEVICE)
        out = model(x)
        loss = loss_fn(out, y)
        loss_value = loss.item()
        batch_losses.append(loss_value)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update progress bar
        loop.set_postfix(loss=loss_value)

    if batch_losses:
        print(f"Mean loss : {sum(batch_losses)/len(batch_losses)}")
    else:
        # E.g. drop_last=True with fewer examples than one batch: there is no
        # loss to average, and dividing by zero would crash the run.
        print("Mean loss : n/a (the loader produced no batches)")
# Model, optimizer and loss used for training and testing.
model = Yolov1(split_size=7, num_boxes=2, num_classes=20).to(DEVICE)
optimizer = optim.Adam(
    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
loss_fn = YoloLoss()

if LOAD_MODEL:
    # map_location remaps the stored tensors onto the device in use, so a
    # checkpoint saved on a GPU can also be loaded on a CPU-only machine.
    load_checkpoint(
        torch.load(LOAD_MODEL_FILE, map_location=DEVICE), model, optimizer
    )
# Datasets
train_dataset = VOCDataset(
    HUNDRED_EXAMPLE_DATASET_FILEPATH,
    transform=transform,
    img_dir=IMG_DIR,
    label_dir=LABEL_DIR,
)

test_dataset = VOCDataset(
    TEST_DATASET_FILEPATH, transform=transform, img_dir=IMG_DIR, label_dir=LABEL_DIR
)

# Training batches are shuffled; the last, incomplete batch is dropped.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    shuffle=True,
    drop_last=True,
)
print("Training set loaded...")

# Evaluation has to see every test image, in a stable order: no shuffling and
# no dropping of the last partial batch.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    shuffle=False,
    drop_last=False,
)
print("Testing set loaded...")
# Best training mAP seen so far.
best_acc = 0.0
if IS_TRAIN:
    start_training = datetime.datetime.now()
    for epoch in range(EPOCHS):
        print(f"Running epoch : {epoch + 1}/{EPOCHS}")
        # Evaluate on the device the model lives on; the default of
        # get_bboxes is "cuda", which fails on CPU-only machines.
        pred_boxes_train, target_boxes_train = get_bboxes(
            train_loader, model, iou_threshold=0.5, threshold=0.4, device=DEVICE
        )
        mean_avg_prec_train = mean_average_precision(
            pred_boxes_train, target_boxes_train, iou_threshold=0.5, box_format="midpoint"
        )
        print(f"Train mAP: {mean_avg_prec_train}")

        # The comparison used the undefined name mean_avg_prec (NameError).
        if mean_avg_prec_train > best_acc:
            best_acc = mean_avg_prec_train
        # Optional checkpointing of the best model, disabled by default:
        # if mean_avg_prec_train > 0.9 and mean_avg_prec_train == best_acc:
        #     checkpoint = {
        #         "state_dict": model.state_dict(),
        #         "optimizer": optimizer.state_dict(),
        #     }
        #     save_checkpoint(checkpoint, filename=LOAD_MODEL_FILE)
        #     time.sleep(10)

        train_fn(train_loader, model, optimizer, loss_fn)

    # total_seconds() includes whole days; .seconds only holds the remainder
    # below one day and under-reports long runs.
    elapsed = datetime.datetime.now() - start_training
    print(f"Total training time {elapsed.total_seconds()/60} minutes.")
# Visual test: plot the predictions for the first two images of every batch.
# load_checkpoint(torch.load(LOAD_MODEL_FILE), model, optimizer)
# Predict in eval mode and without gradient tracking, the same way get_bboxes
# evaluates the model for the mAP.
model.eval()
with torch.no_grad():
    for x, y in train_loader:
        x = x.to(DEVICE)
        # One forward pass per batch; it used to be repeated for every
        # plotted image.
        batch_bboxes = cellboxes_to_boxes(model(x))
        for idx in range(min(2, x.shape[0])):
            bboxes_pred = non_max_suppression(
                batch_bboxes[idx], iou_threshold=0.5, threshold=0.4, box_format="midpoint"
            )
            plot_image(x[idx].permute(1, 2, 0).to("cpu"), bboxes_pred)
# Leave the model in train mode, as it was before this section.
model.train()