import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, models
from collections import OrderedDict
import PIL.Image
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dataset_location = '/home/marcin/Datasets/udacity-challange-flower-data/flower_data/'
# Define transforms
imgnet_mean, imgnet_std = np.array([0.485, 0.456, 0.406]), np.array([0.229, 0.224, 0.225])
transforms_train = transforms.Compose([
transforms.Resize(256),
transforms.Pad(100, padding_mode='reflect'),
transforms.RandomRotation(45),
transforms.CenterCrop(256),
transforms.RandomResizedCrop(224, scale=(0.8 , 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(imgnet_mean, imgnet_std)
])
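# Quick sanity check (a minimal sketch, not part of the original notebook): regardless of the
# random augmentations, the training pipeline should always yield a normalized 3x224x224
# float tensor. The dummy image below is an assumption used only for this check.
_dummy = PIL.Image.fromarray(np.uint8(np.random.rand(300, 300, 3) * 255))
_out = transforms_train(_dummy)
print(_out.shape, _out.dtype)  # expected: torch.Size([3, 224, 224]) torch.float32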
# transforms_list = transforms.Compose([
# transforms.Resize(256),
# transforms.CenterCrop(224),
# transforms.ToTensor(),
# transforms.Normalize(imgnet_mean, imgnet_std)
# ])
transforms_valid = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
    transforms.Normalize(imgnet_mean, imgnet_std)
])
# Test transforms
def tensor_img_2_numpy(tensor_img):
    # Undo the ImageNet normalization and convert back to a PIL image for plotting;
    # clamp to [0, 1] so small numerical overshoots don't wrap around in to_pil_image
    ttt = transforms.functional.normalize(tensor_img, -imgnet_mean/imgnet_std, 1/imgnet_std)
    return transforms.functional.to_pil_image(ttt.clamp(0, 1))
img = PIL.Image.open(os.path.join(dataset_location, 'train/1/image_06734.jpg'))
fig, axes = plt.subplots(ncols=6, figsize=[16,4])
axes[0].set_title('Original')
axes[0].imshow(img)
axes[1].set_title('Validation')
tensor_img = transforms_valid(img)
axes[1].imshow(tensor_img_2_numpy(tensor_img))
for i in range(2, len(axes)):
axes[i].set_title(f'Train #{i-2}')
tensor_img = transforms_train(img)
axes[i].imshow(tensor_img_2_numpy(tensor_img))
# Create dataloaders
dataset_train = datasets.ImageFolder(os.path.join(dataset_location, 'train'), transforms_train)
dataset_valid = datasets.ImageFolder(os.path.join(dataset_location, 'valid'), transforms_valid)
print('Number of train images:', len(dataset_train))
print('Number of valid images:', len(dataset_valid))
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=16, shuffle=True,
num_workers=6, pin_memory=True)
dataloader_valid = torch.utils.data.DataLoader(dataset_valid, batch_size=16, shuffle=True,
num_workers=6, pin_memory=True)
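# The folder names under train/ are the class labels (category ids as strings). A small
# sketch of how to recover them from the dataset; `cat_to_name.json` (the id-to-flower-name
# mapping that ships with this Udacity dataset) is an assumption and is left commented out.
idx_to_class = {idx: cls for cls, idx in dataset_train.class_to_idx.items()}
print('Number of classes:', len(idx_to_class))
# import json
# with open('cat_to_name.json') as f:
#     cat_to_name = json.load(f)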
model = models.densenet121(pretrained=True)
# Freeze feature parameters
for param in model.parameters():
param.requires_grad = False
# Replace classifier
classifier = nn.Sequential(OrderedDict([
('fc1', nn.Linear(1024, 512)),
('relu', nn.ReLU()),
('dropout', nn.Dropout(0.2)),
('fc2', nn.Linear(512, 102)),
('output', nn.LogSoftmax(dim=1))
]))
model.classifier = classifier
model = model.to(device)
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.001)
# model = models.resnet34(pretrained=True)
# # Freeze feature parameters
# # for param in model.parameters():
# # param.requires_grad = False
# # Replace classifier
# classifier = nn.Sequential(OrderedDict([
# ('fc2', nn.Linear(512, 102)),
# ('output', nn.LogSoftmax(dim=1))
# ]))
# model.fc = classifier
# model = model.to(device)
# criterion = nn.NLLLoss()
# optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)
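# Sanity check (a minimal sketch): with the feature extractor frozen, only the new classifier
# head should be trainable. A step LR scheduler could optionally be added; the schedule
# values below are assumptions, not part of the original run, so it is left commented out.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in model.parameters())
print(f'Trainable parameters: {n_trainable:,} / {n_total:,}')
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)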
# Helper
def accuracy(logits, labels):
    # Fraction of correct predictions; returns a 0-dim tensor (call .item() for a Python float)
    predictions = torch.argmax(logits, dim=1)
    return (predictions == labels).float().mean()
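# Tiny check of the helper (an illustrative sketch): two of the three dummy predictions match.
_logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [0.9, 0.3]])
_labels = torch.tensor([0, 1, 1])
print(accuracy(_logits, _labels).item())  # 0.666...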
# Train model
num_epochs = 20
batch_size = 64  # note: the dataloaders above were created with batch_size=16, so this value is unused here
since = time.time()
hist = { 'tepoch':[], 'tloss':[], 'tacc':[], # mini-batch loss/acc every iteration
'vepoch':[], 'vloss':[], 'vacc':[],
'train_loss':[], 'train_acc':[], # train set loss/acc every epoch
'valid_loss':[], 'valid_acc':[] } # valid set loss/acc every epoch
for epoch in range(num_epochs):
epoch_time_start = time.time()
### Train ###
model.train()
loss_sum = 0
acc_sum = 0
for images, labels in dataloader_train:
# Push to GPU
x = images.to(device)
y = labels.to(device)
# Optimize
optimizer.zero_grad()
outputs = model(x)
loss = criterion(outputs, y)
loss.backward()
optimizer.step()
# Record per-iteration stats
with torch.no_grad():
acc = accuracy(outputs, y)
loss_sum += loss.item() * len(images)
acc_sum += acc.item() * len(images)
hist['tepoch'].append( epoch )
hist['tacc'].append( acc.item() )
hist['tloss'].append( loss.item() )
hist['train_loss'].append( loss_sum / len(dataset_train) )
hist['train_acc'].append( acc_sum / len(dataset_train) )
### Evaluate ###
model.eval()
loss_sum = 0
acc_sum = 0
for images, labels in dataloader_valid:
# Push to GPU
x = images.to(device)
y = labels.to(device)
with torch.no_grad():
outputs = model(x)
loss = criterion(outputs, y)
acc = accuracy(outputs, y)
loss_sum += loss.item() * len(images)
acc_sum += acc.item() * len(images)
hist['vepoch'].append( epoch )
hist['vloss'].append( loss.item() )
hist['vacc'].append( acc.item() )
hist['valid_loss'].append( loss_sum / len(dataset_valid) )
hist['valid_acc'].append( acc_sum / len(dataset_valid) )
epoch_time_interval = time.time() - epoch_time_start
### Print Summary ###
if epoch == 0:
print(' (time ) ep loss / acc loss / acc')
print(f'Epoch ({epoch_time_interval:4.2f}s): {epoch:3}'
f' Train: {hist["train_loss"][-1]:6.4f} / {hist["train_acc"][-1]:6.4f}'
f' Valid: {hist["valid_loss"][-1]:6.4f} / {hist["valid_acc"][-1]:6.4f}')
print(f'Total training time: {time.time() - since:.1f}s')
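# Saving the fine-tuned head for later reuse. A minimal sketch, not from the original
# notebook; the file name 'checkpoint.pth' and the keys stored are assumptions.
checkpoint = {
    'classifier_state_dict': model.classifier.state_dict(),
    'class_to_idx': dataset_train.class_to_idx,
    'hist': hist,
}
torch.save(checkpoint, 'checkpoint.pth')
# To restore: rebuild densenet121 with the same classifier head, then
# model.classifier.load_state_dict(torch.load('checkpoint.pth')['classifier_state_dict'])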
def pretty_plot(ax, data, label, color, alpha=1.0):
    # Scatter the raw per-iteration values and overlay a moving average for readability
    def smooth(y, n):
        return np.convolve(y, v=np.ones(n)/n, mode='same')
    ax.scatter(range(len(data)), data, marker='.', s=2, color=color, alpha=alpha)
    ax.plot(smooth(data, 20), label=label, color=color, alpha=alpha)
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=[16,6])
pretty_plot(ax, hist['tloss'], 'tloss', 'blue')
pretty_plot(ax.twiny(), hist['vloss'], 'vloss', 'red')
ax.set_title('Train Loss Every Iteration');
ax.grid();
def plot_hist(hist, title):
fig, (ax, ax2) = plt.subplots(nrows=1, ncols=2, figsize=[16,3])
fig.suptitle(title, fontsize=16)
#ax.plot(hist['train_loss'], label='train_loss', color='blue')
pretty_plot(ax.twiny(), hist['tloss'], 'tloss', color='blue', alpha=.5)
ax.plot(hist['valid_loss'], label='valid_loss', color='orange')
ax.set_title('Loss'); ax.legend(); ax.grid(); ax.set_ylim([0, 1]);
#fig, ax = plt.subplots(nrows=1, ncols=1, figsize=[16,3])
ax2.plot(hist['train_acc'], label='train_acc', color='blue')
#pretty_plot(ax2.twiny(), hist['tacc'], 'tacc', color='blue', alpha=1)
ax2.plot(hist['valid_acc'], label='valid_acc', color='orange')
ax2.set_title('Accuracy'); ax2.legend(); ax2.grid(); ax2.set_ylim([.8, 1]);
plt.tight_layout()
# Each call below was re-run after training the corresponding configuration
plot_hist(hist, title='DenseNet121 aug')
plot_hist(hist, title='ResNet34 aug')
plot_hist(hist, title='ResNet34 no-aug')
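# Inference sketch (an assumption about the intended usage, not code from the original
# notebook): since the head ends in LogSoftmax, exponentiating the outputs gives class
# probabilities, and torch.topk picks the most likely flower categories.
def predict(image_path, model, topk=5):
    model.eval()
    img = PIL.Image.open(image_path)
    x = transforms_valid(img).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        log_probs = model(x)
    probs, indices = torch.exp(log_probs).topk(topk, dim=1)
    return probs.squeeze(0).tolist(), indices.squeeze(0).tolist()

# Example usage (path is illustrative only):
# probs, classes = predict('path/to/image.jpg', model)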