This notebook shows how to train SegNet on the Pascal VOC2007 dataset.
This notebook and the accompanying segnet.py file are based on code from https://github.com/Sayan98/pytorch-segnet
My contributions include:
References
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import PIL
import torch
import segnet # segnet.py contains model definition
Acquire GPU device if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Download the VOC2007 dataset from the link in the references. The file should be named VOCtrainval_06-Nov-2007.tar and weigh approximately 460 MB. Extract the VOC2007 folder and modify the path below.
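The download and extraction can also be scripted; a minimal sketch, assuming the official PASCAL VOC mirror URL below is still reachable and that the archive unpacks into VOCdevkit/VOC2007:
# Sketch only: download and unpack VOC2007 (URL and extraction layout are assumptions)
import urllib.request
import tarfile
voc_url = 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar'
tar_path = 'VOCtrainval_06-Nov-2007.tar'
if not os.path.exists(tar_path):
    urllib.request.urlretrieve(voc_url, tar_path)   # ~460 MB download
with tarfile.open(tar_path) as tar:
    tar.extractall()                                # creates VOCdevkit/VOC2007/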
# Dataset v3
num_classes = 22 # background, airplane, ..., border
data_root = '/home/marcin/Datasets/VOC2007' # Dataset location
batch_size = 16 # Mini-batch size
Class names, for reference only
# for reference, not used in this notebook
voc_classes = ('background', # always index 0
'aeroplane', 'bicycle', 'bird', 'boat', # indices 1, 2, 3, 4
'bottle', 'bus', 'car', 'cat', 'chair', # 5, ...
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor', # ..., 21
'border') # but border has index 255 (!)
assert num_classes == len(voc_classes)
The class below reads a segmentation dataset in a VOC2007-compatible format.
class PascalVOCDataset(torch.utils.data.Dataset):
"""Pascal VOC2007 or compatible dataset"""
def __init__(self, num_classes, list_file, img_dir, mask_dir, transform=None):
self.num_classes = num_classes
self.images = open(list_file, "rt").read().split("\n")[:-1]
self.transform = transform
self.img_extension = ".jpg"
self.mask_extension = ".png"
self.image_root_dir = img_dir
self.mask_root_dir = mask_dir
self.counts = self.__compute_class_probability()
def __len__(self):
return len(self.images)
def __getitem__(self, index):
name = self.images[index]
image_path = os.path.join(self.image_root_dir, name + self.img_extension)
mask_path = os.path.join(self.mask_root_dir, name + self.mask_extension)
image = self.load_image(path=image_path)
gt_mask = self.load_mask(path=mask_path)
return torch.FloatTensor(image), torch.LongTensor(gt_mask)
def __compute_class_probability(self):
counts = dict((i, 0) for i in range(self.num_classes))
for name in self.images:
mask_path = os.path.join(self.mask_root_dir, name + self.mask_extension)
raw_image = PIL.Image.open(mask_path).resize((224, 224))
imx_t = np.array(raw_image).reshape(224*224)
imx_t[imx_t==255] = self.num_classes-1 # convert VOC border into last class
for i in range(self.num_classes):
counts[i] += np.sum(imx_t == i)
return counts
def get_class_probability(self):
values = np.array(list(self.counts.values()))
p_values = values/np.sum(values)
return torch.Tensor(p_values)
def load_image(self, path=None):
raw_image = PIL.Image.open(path)
raw_image = np.transpose(raw_image.resize((224, 224)), (2,1,0))
imx_t = np.array(raw_image, dtype=np.float32)/255.0
return imx_t
def load_mask(self, path=None):
raw_image = PIL.Image.open(path)
raw_image = raw_image.resize((224, 224))
imx_t = np.array(raw_image)
imx_t[imx_t==255] = self.num_classes-1 # convert VOC border into last class
return imx_t
Dataset internal structure; no need to modify this.
train_path = os.path.join(data_root, 'ImageSets/Segmentation/train.txt')
val_path = os.path.join(data_root, 'ImageSets/Segmentation/val.txt')
img_dir = os.path.join(data_root, "JPEGImages")
mask_dir = os.path.join(data_root, "SegmentationClass")
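For reference, these paths assume the standard VOC2007 layout under data_root, roughly:
# Expected layout under data_root (standard VOC2007 structure):
#   ImageSets/Segmentation/train.txt, val.txt   # image name lists, one name per line
#   JPEGImages/<name>.jpg                       # input images
#   SegmentationClass/<name>.png                # class-index masks (255 marks the border)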
Create train and validation datasets
train_dataset = PascalVOCDataset(num_classes=num_classes, list_file=train_path,
img_dir=img_dir, mask_dir=mask_dir)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
shuffle=True, num_workers=4)
val_dataset = PascalVOCDataset(num_classes=num_classes, list_file=val_path,
img_dir=img_dir, mask_dir=mask_dir)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size,
shuffle=True, num_workers=4)
Dataset statistics
print('Train Dataset:')
print(' length:', len(train_dataset))
print(' classes:', train_dataset.counts)
print()
print('Validation Dataset:')
print(' length:', len(val_dataset))
print(' classes:', val_dataset.counts)
Show example image and mask
image, mask = train_dataset[11]
image.transpose_(0, 2)
print('image shape:', list(image.shape))
print('mask shape: ', list(mask.shape))
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8,4))
ax1.imshow(image)
ax2.imshow(mask)
plt.show()
Construct a mini-batch of test images. This is a one-element mini-batch used to plot model predictions as the model trains.
test_image_idx = 11 # only one image, but could be more
test_images = train_dataset[test_image_idx][0].unsqueeze(0)
test_masks = train_dataset[test_image_idx][1].unsqueeze(0)
Helper for plotting, used a bit later
def plot_test_batch(test_images, test_logits, test_masks):
for i in range(len(test_images)):
fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=1, ncols=4,
figsize=(12, 4*len(test_images)))
ax1.imshow(test_images[i].transpose(0, 2).numpy()); ax1.set_title('Input Image')
predicted_mx = test_logits[i].numpy()
predicted_mx = predicted_mx.argmax(axis=0)
ax2.imshow(predicted_mx, vmin=0, vmax=num_classes); ax2.set_title('Predicted Mask')
target_mx = test_masks[i].numpy()
ax3.imshow(target_mx, vmin=0, vmax=num_classes); ax3.set_title('Ground Truth')
        err_mx = predicted_mx != target_mx  # per-pixel error map
        ax4.imshow(err_mx); ax4.set_title('Prediction Errors')
plt.show()
save_model_path = './models/model_best.pth'
learning_rate = 1e-6 # 1e-6
Create directory to save model to
dir_path = os.path.split(save_model_path)[0]
if not os.path.exists(dir_path):
os.makedirs(dir_path)
Create model
model = segnet.SegNet(input_channels=3, output_channels=num_classes).to(device)
class_weights = 1.0 / train_dataset.get_class_probability().to(device)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
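The class weights are the inverse of the pixel-frequency estimates computed by the dataset, so the dominant background class should receive the smallest weight; a quick sanity check:
# Inspect the inverse-frequency weights (background is index 0)
print('background weight:', class_weights[0].item())
print('largest weight: class', class_weights.argmax().item(), '=', class_weights.max().item())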
Optionally test model on randomly generated input mini-batch.
test_model_on_random_input = False
if test_model_on_random_input:
img = torch.randn([4, 3, 224, 224]).to(device)
logits, probabilities = model(img)
print(logits.size())
print(probabilities.size())
print(logits[0,:,0,0])
print(probabilities[0,:,0,0].sum())
Test model on test_images
with torch.no_grad():
test_logits, _ = model(test_images.to(device)) # pass through model
test_logits = test_logits.cpu()
plot_test_batch(test_images, test_logits, test_masks)
Helper for plotting
def plot_trace(trace):
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))
ax1.plot(trace['tloss'], label='train loss')
ax1.plot(trace['vloss'], label='valid loss')
ax1.set_xlabel('Epoch')
ax1.legend()
ax2.plot(trace['tacc'], label='train acc')
ax2.plot(trace['vacc'], label='valid acc')
ax2.set_xlabel('Epoch')
ax2.legend()
plt.show()
Calculate per-pixel accuracy; returns a tensor.
def accuracy(logits, labels):
predictions = torch.argmax(logits, dim=1)
return (predictions == labels).float().mean() # tensor!!
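As a quick check of the helper above, a hand-made 1x2-pixel batch with two classes (values are made up for illustration):
# Toy check: logits of shape (N=1, C=2, H=1, W=2); per-pixel argmax gives classes [0, 1]
_toy_logits = torch.tensor([[[[2.0, 0.1]], [[0.5, 1.0]]]])
_toy_labels = torch.tensor([[[0, 0]]])       # only the first pixel is correct
print(accuracy(_toy_logits, _toy_labels))    # expected: tensor(0.5000)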
Helper function for training
def train(nb_epochs, trace):
prev_loss = float('inf')
epoch = len(trace['epoch'])
for _ in range(nb_epochs):
time_start = time.time()
#
# Train Model
#
model.train()
tloss_sum, tacc_sum = 0, 0
for inputs, targets in train_dataloader:
inputs = inputs.to(device)
targets = targets.to(device)
optimizer.zero_grad()
logits, _ = model(inputs)
loss = criterion(logits, targets)
loss.backward()
optimizer.step()
with torch.no_grad():
acc = accuracy(logits, targets)
tloss_sum += loss.item() * len(inputs)
tacc_sum += acc.item() * len(inputs)
tloss_avg = tloss_sum / len(train_dataset)
tacc_avg = tacc_sum / len(train_dataset)
#
# Evaluate Model
#
model.eval()
with torch.no_grad():
vloss_sum, vacc_sum = 0, 0
for inputs, targets in val_dataloader:
inputs = inputs.to(device)
targets = targets.to(device)
logits, _ = model(inputs)
loss = criterion(logits, targets)
acc = accuracy(logits, targets)
vloss_sum += loss.item() * len(inputs)
vacc_sum += acc.item() * len(inputs)
vloss_avg = vloss_sum / len(val_dataset)
vacc_avg = vacc_sum / len(val_dataset)
#
# Logging
#
time_delta = time.time() - time_start
trace['epoch'].append(epoch)
trace['tloss'].append(tloss_avg)
trace['tacc'].append(tacc_avg)
trace['vloss'].append(vloss_avg)
trace['vacc'].append(vacc_avg)
if vloss_avg < prev_loss:
prev_loss = vloss_avg
if save_model_path is not None:
torch.save(model.state_dict(), save_model_path)
print(f'Epoch: {epoch:3} T/V Loss: {tloss_avg:.4f} / {vloss_avg:.4f} '
f'T/V Acc: {tacc_avg:.4f} / {vacc_avg:.4f} Time: {time_delta:.2f}s')
if (epoch+1) % 10 == 0:
plot_trace(trace)
with torch.no_grad():
test_logits, _ = model(test_images.to(device)) # pass through model
test_logits = test_logits.cpu()
plot_test_batch(test_images, test_logits, test_masks)
epoch += 1
Do actual training. Each call below continues training for another 20 epochs.
trace = {'epoch': [], 'tloss': [], 'vloss': [], 'tacc': [], 'vacc': []}
train(nb_epochs=20, trace=trace)
train(nb_epochs=20, trace=trace)
train(nb_epochs=20, trace=trace)
train(nb_epochs=20, trace=trace)
train(nb_epochs=20, trace=trace)
Optionally load pre-trained model
# pretrained_model_path = './models/model_best.pth'
# model.load_state_dict(torch.load(pretrained_model_path, map_location=device))
def validate():
model.eval()
with torch.no_grad():
for inputs, targets in val_dataloader:
logits, _ = model(inputs.to(device))
logits = logits.cpu()
plot_test_batch(inputs, logits, targets)
This will load every image from the validation set and plot the model prediction against the ground truth.
print('Current Epoch', len(trace['epoch']))
validate()
# TODO: Change me
VOC_CLASSES = ('background', # always index 0
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')
NUM_CLASSES = len(VOC_CLASSES) + 1
NUM_CLASSES
VOC_CLASSES = ('background', # always index 0
'pipe')
NUM_CLASSES = len(VOC_CLASSES)
# TODO: Change me
class PascalVOCDataset(torch.utils.data.Dataset):
"""Pascal VOC 2007 Dataset"""
def __init__(self, list_file, img_dir, mask_dir, transform=None):
self.images = open(list_file, "rt").read().split("\n")[:-1]
self.transform = transform
self.img_extension = ".jpg"
self.mask_extension = ".png"
self.image_root_dir = img_dir
self.mask_root_dir = mask_dir
self.counts = self.__compute_class_probability()
def __len__(self):
return len(self.images)
def __getitem__(self, index):
name = self.images[index]
image_path = os.path.join(self.image_root_dir, name + self.img_extension)
mask_path = os.path.join(self.mask_root_dir, name + self.mask_extension)
image = self.load_image(path=image_path)
gt_mask = self.load_mask(path=mask_path)
data = {
'image': torch.FloatTensor(image),
'mask' : torch.LongTensor(gt_mask)
}
return data
def __compute_class_probability(self):
counts = dict((i, 0) for i in range(NUM_CLASSES))
for name in self.images:
mask_path = os.path.join(self.mask_root_dir, name + self.mask_extension)
raw_image = PIL.Image.open(mask_path).resize((224, 224))
imx_t = np.array(raw_image).reshape(224*224)
imx_t[imx_t==255] = len(VOC_CLASSES)
for i in range(NUM_CLASSES):
counts[i] += np.sum(imx_t == i)
return counts
def get_class_probability(self):
values = np.array(list(self.counts.values()))
p_values = values/np.sum(values)
return torch.Tensor(p_values)
def load_image(self, path=None):
raw_image = PIL.Image.open(path)
raw_image = np.transpose(raw_image.resize((224, 224)), (2,1,0))
imx_t = np.array(raw_image, dtype=np.float32)/255.0
return imx_t
def load_mask(self, path=None):
raw_image = PIL.Image.open(path)
raw_image = raw_image.resize((224, 224))
imx_t = np.array(raw_image)
# border
imx_t[imx_t==255] = len(VOC_CLASSES)
return imx_t
# data_root = '/home/marcin/Datasets/VOC2007'              # alternative dataset location
# train_txt, val_txt = 'train_mini.txt', 'val_mini.txt'
data_root = '/home/marcin/Datasets/rovco/dataset'
train_txt, val_txt = 'train.txt', 'val.txt'
train_path = os.path.join(data_root, 'ImageSets/Segmentation', train_txt)
val_path = os.path.join(data_root, 'ImageSets/Segmentation', val_txt)
img_dir = os.path.join(data_root, "JPEGImages")
mask_dir = os.path.join(data_root, "SegmentationClass")
save_dir = './savedir'
checkpoint = None
CUDA = True # args.gpu is not None
GPU_ID = 0 # args.gpu
BATCH_SIZE = 16
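The training loop below saves checkpoints into save_dir, so make sure the directory exists (mirroring the earlier cell that created ./models):
# Ensure the checkpoint directory exists before training tries to save into it
os.makedirs(save_dir, exist_ok=True)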
train_dataset = PascalVOCDataset(list_file=train_path,
img_dir=img_dir,
mask_dir=mask_dir)
train_dataloader = torch.utils.data.DataLoader(train_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=4)
print(train_dataset.get_class_probability())
sample = train_dataset[11]
image, mask = sample['image'], sample['mask']
image.transpose_(0, 2)
fig = plt.figure()
a = fig.add_subplot(1,2,1)
plt.imshow(image)
a = fig.add_subplot(1,2,2)
plt.imshow(mask)
plt.show()
test_model_on_random_input = False
if test_model_on_random_input:
# Model
model = segnet.SegNet(input_channels=3, output_channels=NUM_CLASSES)
# print(model)
img = torch.randn([4, 3, 224, 224])
output, softmaxed_output = model(img)
print(output.size())
print(softmaxed_output.size())
print(output[0,:,0,0])
print(softmaxed_output[0,:,0,0].sum())
# TODO: Change me
# Constants
NUM_INPUT_CHANNELS = 3
NUM_OUTPUT_CHANNELS = NUM_CLASSES
NUM_EPOCHS = 6000
LEARNING_RATE = 1e-3 # 1e-6
MOMENTUM = 0.9
if CUDA:
model = segnet.SegNet(input_channels=NUM_INPUT_CHANNELS,
output_channels=NUM_OUTPUT_CHANNELS).cuda(GPU_ID)
class_weights = 1.0/train_dataset.get_class_probability().cuda(GPU_ID)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights).cuda(GPU_ID)
else:
model = segnet.SegNet(input_channels=NUM_INPUT_CHANNELS,
output_channels=NUM_OUTPUT_CHANNELS)
class_weights = 1.0/train_dataset.get_class_probability()
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
if checkpoint:
    model.load_state_dict(torch.load(checkpoint))
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
def train():
is_better = True
prev_loss = float('inf')
model.train()
for epoch in range(NUM_EPOCHS):
loss_f = 0
t_start = time.time()
for batch in train_dataloader:
            input_tensor = batch['image']   # Variable wrapper is no longer needed in modern PyTorch
            target_tensor = batch['mask']
if CUDA:
input_tensor = input_tensor.cuda(GPU_ID)
target_tensor = target_tensor.cuda(GPU_ID)
predicted_tensor, softmaxed_tensor = model(input_tensor)
optimizer.zero_grad()
            loss = criterion(predicted_tensor, target_tensor)  # CrossEntropyLoss expects raw logits
loss.backward()
optimizer.step()
            loss_f += loss.item()  # .item() detaches the value and avoids keeping the graph in memory
delta = time.time() - t_start
is_better = loss_f < prev_loss
if is_better:
prev_loss = loss_f
torch.save(model.state_dict(), os.path.join(save_dir, "model_best.pth"))
print("Epoch #{}\tLoss: {:.8f}\t Time: {:2f}s".format(epoch+1, loss_f, delta))
train()
def validate():
model.eval()
for batch_idx, batch in enumerate(val_dataloader):
        input_tensor = batch['image']
        target_tensor = batch['mask']
if CUDA:
input_tensor = input_tensor.cuda(GPU_ID)
target_tensor = target_tensor.cuda(GPU_ID)
predicted_tensor, softmaxed_tensor = model(input_tensor)
loss = criterion(predicted_tensor, target_tensor)
for idx, predicted_mask in enumerate(softmaxed_tensor):
target_mask = target_tensor[idx]
input_image = input_tensor[idx]
fig = plt.figure()
a = fig.add_subplot(1,3,1)
plt.imshow(input_image.transpose(0, 2).cpu().numpy())
a.set_title('Input Image')
a = fig.add_subplot(1,3,2)
predicted_mx = predicted_mask.detach().cpu().numpy()
predicted_mx = predicted_mx.argmax(axis=0)
plt.imshow(predicted_mx)
a.set_title('Predicted Mask')
a = fig.add_subplot(1,3,3)
target_mx = target_mask.detach().cpu().numpy()
plt.imshow(target_mx)
a.set_title('Ground Truth')
#fig.savefig(os.path.join(OUTPUT_DIR, "prediction_{}_{}.png".format(batch_idx, idx)))
#plt.close(fig)
SAVED_MODEL_PATH = './savedir/model_best.pth'
val_dataset = PascalVOCDataset(list_file=val_path,
img_dir=img_dir,
mask_dir=mask_dir)
val_dataloader = torch.utils.data.DataLoader(val_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
num_workers=4)
# if CUDA:
# model = segnet.SegNet(input_channels=NUM_INPUT_CHANNELS,
# output_channels=NUM_OUTPUT_CHANNELS).cuda(GPU_ID)
# class_weights = 1.0/val_dataset.get_class_probability().cuda(GPU_ID)
# criterion = torch.nn.CrossEntropyLoss(weight=class_weights).cuda(GPU_ID)
# else:
# model = segnet.SegNet(input_channels=NUM_INPUT_CHANNELS,
# output_channels=NUM_OUTPUT_CHANNELS)
# class_weights = 1.0/val_dataset.get_class_probability()
# criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
# model.load_state_dict(torch.load(SAVED_MODEL_PATH))
validate()