import os
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# from torchvision import datasets
Download dataset
# dataset = datasets.SVHN(root='~/.pytorch/svhn/', split='train', download=True, transform=None)
Because this is a very small dataset, we will load it fully onto the GPU; otherwise the dataloader will slow down training considerably (see the timing sketch below).
# x_train = (dataset.data.astype(np.float32) / 127.5) - 1 # scale to range [-1..1]
# x_train = x_train[:-41] # make divisible by 512
# print(x_train.shape)
Visualize the data
# fig, axes = plt.subplots(nrows=1, ncols=8, figsize=[16,6])
# for i in range(len(axes)):
#     idx = np.random.randint(0, len(x_train))
#     img = x_train[idx]
#     img = img/2 + .5  # -1..1 -> 0..1
#     img = img.transpose(1, 2, 0)
#     axes[i].imshow(img); axes[i].axis('off')
Move to GPU if available
# x_train = torch.tensor(x_train, device=device)
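To see the difference that note is about, here is a minimal timing sketch (illustrative only, using random stand-in data rather than the real dataset): it compares per-batch CPU-to-GPU copies through a DataLoader against plain slicing of a tensor that already lives on the GPU.
# Minimal sketch (illustrative, random stand-in data): per-batch
# CPU->GPU copies via a DataLoader vs slicing a GPU-resident tensor.
from torch.utils.data import DataLoader, TensorDataset
x = torch.randn(20000, 3, 32, 32)               # toy dataset on the CPU
loader = DataLoader(TensorDataset(x), batch_size=512)
t0 = time.time()
for (batch,) in loader:
    batch = batch.to(device)                    # copy every batch to the GPU
print(f'DataLoader + copies: {time.time()-t0:.2f}s')
x_gpu = x.to(device)                            # single upfront copy
t0 = time.time()
for i in range(0, len(x_gpu), 512):
    batch = x_gpu[i:i+512]                      # no copy, just a view
print(f'GPU slicing: {time.time()-t0:.2f}s')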
Point this to the folder with the CelebA dataset unzipped
dataset_location = '/home/marcin/Datasets/img_align_celeba'
datafile = os.path.join(dataset_location, 'img_align_celeba_32x32.npz')
print(datafile)
import PIL.Image  # bare "import PIL" may not expose the Image submodule
Get the names of all images in the CelebA data folder
all_files = os.listdir(os.path.join(dataset_location, 'img_align_celeba'))
Images in CelebA are pre-aligned so that the eye and mouth locations are almost identical across images. We will use this fact and hardcode the crop to 128x128, then rescale to 32x32.
def crop_celeba(img):
    assert img.size == (178, 218)
    # crop 25/65/25/25 from left/top/right/bottom: 178-25-25 = 128, 218-65-25 = 128
    box = [25, 65, img.width-25, img.height-25]
    return img.crop(box)  # result.size = (128, 128)
Show a couple of images with and without cropping
fig, (axes1, axes2) = plt.subplots(nrows=2, ncols=6, figsize=[16,6])
for i in range(len(axes1)):
    idx = i  # np.random.randint(0, len(all_files))
    img_full_path = os.path.join(dataset_location, 'img_align_celeba', all_files[idx])
    img = PIL.Image.open(img_full_path)
    img_128x128 = crop_celeba(img)
    img_64x64 = img_128x128.resize([64, 64], PIL.Image.BICUBIC)
    axes1[i].imshow(img)
    axes1[i].set_title('original')
    axes2[i].imshow(img_64x64)
    axes2[i].set_title('processed')
This function will read, crop, and scale all images, then put them into one large NumPy array
def load_all_images(dataset_location, all_files, size=32):
    all_images = []
    for i in range(len(all_files)):
        img_full_path = os.path.join(dataset_location, 'img_align_celeba', all_files[i])
        img = PIL.Image.open(img_full_path)
        img_128x128 = crop_celeba(img)
        img_small = img_128x128.resize([size, size], PIL.Image.BICUBIC)
        arr = np.array(img_small)
        assert arr.shape == (size, size, 3)
        all_images.append(arr)
        if i % 10000 == 0:
            print(f'Image {i} of {len(all_files)}')
    return np.array(all_images)
Check if the dataset was processed earlier: if not, process it and save it to an .npz file; if yes, just load it from the file.
if not os.path.isfile(datafile):
    print('Creating datafile...')
    all_images = load_all_images(dataset_location, all_files, size=32)
    np.savez(datafile, all_images=all_images)
else:
    print('Loading from file...')
    npzfile = np.load(datafile)
    all_images = npzfile['all_images']
Convert to float32, scale to the -1..1 range, and re-order dimensions to match the PyTorch NCHW convention
x_train = (all_images.astype(np.float32) / 127.5) - 1 # scale to range [-1..1]
x_train = x_train.transpose(0, 3, 1, 2) # pytorch ordering NCHW
x_train = x_train[:-103] # drop 103 images so the dataset size is divisible by the batch size (128)
x_train.shape
del all_images # save memory
Move everything to the GPU if available (approx. 2.5 GB)
x_train = torch.tensor(x_train, device=device)
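For reference, here is where the approx. 2.5 GB figure comes from (simple arithmetic, assuming the full 202,599-image CelebA minus the 103 trimmed images):
# Illustrative arithmetic behind the ~2.5 GB estimate
n_images = 202599 - 103                    # CelebA images kept after trimming
bytes_total = n_images * 3 * 32 * 32 * 4   # NCHW float32 = 4 bytes per value
print(f'{bytes_total / 1e9:.2f} GB')       # ~2.49 GB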
Discriminator
class Discriminator(nn.Module):
    def __init__(self, conv_dim=32):
        super(Discriminator, self).__init__()
        # 3 x 32x32 -> conv_dim x 16x16
        self.L1_conv = nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1, bias=False)
        self.L1_act = nn.LeakyReLU(0.2)
        # conv_dim x 16x16 -> 2*conv_dim x 8x8
        self.L2_conv = nn.Conv2d(conv_dim, 2*conv_dim, kernel_size=4, stride=2, padding=1, bias=False)
        self.L2_bn = nn.BatchNorm2d(2*conv_dim)
        self.L2_act = nn.LeakyReLU(0.2)
        # 2*conv_dim x 8x8 -> 4*conv_dim x 4x4
        self.L3_conv = nn.Conv2d(2*conv_dim, 4*conv_dim, kernel_size=4, stride=2, padding=1, bias=False)
        self.L3_bn = nn.BatchNorm2d(4*conv_dim)
        self.L3_act = nn.LeakyReLU(0.2)
        # flatten -> single real/fake logit
        self.L4_fc = nn.Linear(in_features=4*conv_dim*4*4, out_features=1)

    def forward(self, x):
        x = self.L1_conv(x)
        x = self.L1_act(x)
        x = self.L2_conv(x)
        x = self.L2_bn(x)
        x = self.L2_act(x)
        x = self.L3_conv(x)
        x = self.L3_bn(x)
        x = self.L3_act(x)
        x = x.view(x.size(0), -1)  # flatten
        x = self.L4_fc(x)
        return x  # return logits
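A quick shape sanity check can catch sizing mistakes early. The sketch below is illustrative (fresh model, dummy random batch): it pushes four fake 32x32 images through the discriminator and confirms we get one logit per image.
# Sanity check (illustrative): 4 random 32x32 RGB images -> 4 logits
d_test = Discriminator(conv_dim=32)
dummy = torch.randn(4, 3, 32, 32)
print(d_test(dummy).shape)  # expected: torch.Size([4, 1])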
Generator
class Generator(nn.Module):
    def __init__(self, z_size, conv_dim=32):
        super(Generator, self).__init__()
        self.conv_dim = conv_dim
        # z -> 4*conv_dim x 4x4 feature map (after reshape in forward)
        self.L0_fc = nn.Linear(in_features=z_size, out_features=4*conv_dim*4*4)
        # 4*conv_dim x 4x4 -> 2*conv_dim x 8x8
        self.L1_convT = nn.ConvTranspose2d(4*conv_dim, 2*conv_dim, kernel_size=4, stride=2, padding=1, bias=False)
        self.L1_bn = nn.BatchNorm2d(2*conv_dim)
        self.L1_act = nn.ReLU()
        # 2*conv_dim x 8x8 -> conv_dim x 16x16
        self.L2_convT = nn.ConvTranspose2d(2*conv_dim, conv_dim, kernel_size=4, stride=2, padding=1, bias=False)
        self.L2_bn = nn.BatchNorm2d(conv_dim)
        self.L2_act = nn.ReLU()
        # conv_dim x 16x16 -> 3 x 32x32
        self.L3_convT = nn.ConvTranspose2d(conv_dim, 3, kernel_size=4, stride=2, padding=1, bias=False)
        self.L3_act = nn.Tanh()

    def forward(self, x):
        x = self.L0_fc(x)
        x = x.view(x.size(0), 4*self.conv_dim, 4, 4)
        x = self.L1_convT(x)
        x = self.L1_bn(x)
        x = self.L1_act(x)
        x = self.L2_convT(x)
        x = self.L2_bn(x)
        x = self.L2_act(x)
        x = self.L3_convT(x)
        x = self.L3_act(x)
        return x  # image pixels in range -1..1
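And the mirror-image check for the generator (again illustrative, fresh model and random noise): 100-dim noise vectors should come out as 3x32x32 images squashed into -1..1 by the final tanh.
# Sanity check (illustrative): 4 noise vectors -> 4 RGB 32x32 images in -1..1
g_test = Generator(z_size=100, conv_dim=32)
z = torch.rand(4, 100)*2 - 1
out = g_test(z)
print(out.shape, out.min().item(), out.max().item())  # torch.Size([4, 3, 32, 32]), values within -1..1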
Build the models, loss, and optimizers
conv_dim = 32
z_size = 100
generator = Generator(z_size=z_size, conv_dim=conv_dim)
discriminator = Discriminator(conv_dim)
generator.to(device)
discriminator.to(device)
criterion = nn.BCEWithLogitsLoss()
lr, betas = 0.0002, (0.5, 0.999)
d_optimizer = optim.Adam(discriminator.parameters(), lr, betas)
g_optimizer = optim.Adam(generator.parameters(), lr, betas)
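We use BCEWithLogitsLoss because it fuses the sigmoid with binary cross-entropy, which is more numerically stable than applying the two separately. A small illustrative check that the two routes agree:
# Illustrative: BCEWithLogitsLoss(logits) == BCELoss(sigmoid(logits)) up to float precision
logits = torch.tensor([[2.0], [-1.0]])
targets = torch.tensor([[1.0], [0.0]])
print(nn.BCEWithLogitsLoss()(logits, targets).item(),
      nn.BCELoss()(torch.sigmoid(logits), targets).item())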
Helpers
def plot_images(x_fake):  # expects tensor of shape [batch, 3, height, width]
    fig, axes = plt.subplots(nrows=1, ncols=10, figsize=[16,9])
    for i, ax in enumerate(axes):
        img = x_fake[i].detach().cpu().permute(1, 2, 0).numpy()
        ax.imshow(img/2 + .5)  # -1..1 -> 0..1
        ax.axis('off')
    plt.show()
def plot_loss(losses):
    fig, ax = plt.subplots(nrows=1, ncols=1)
    ax.plot(losses['disc'], label='disc')
    ax.plot(losses['gen'], label='gen')
    ax.legend()  # without this the labels are never shown
    plt.show()
A few more hyperparameters. fixed_noise is used to generate a consistent set of sample images every epoch.
n_batch = 128
n_epochs = 1 # even 1 epoch will show some face-ish results; 10-100 training epochs are recommended
fixed_noise = torch.rand(16, z_size, device=device)*2-1 # uniform -1..1
Train the model
losses = {'gen':[], 'disc':[]}
indices = np.array(range(len(x_train)))
iteration = 0
for e in range(n_epochs):
    time_start = time.time()
    np.random.shuffle(indices)
    for i in range(0, len(x_train), n_batch):

        # Pick next batch of real images
        i_batch = indices[i:i+n_batch]
        x_real = x_train[i_batch]

        # Generate fake images
        noise = torch.rand(len(x_real), z_size, device=device)*2-1  # uniform -1..1
        x_fake = generator(noise)

        # Train the discriminator
        d_optimizer.zero_grad()
        outputs_real = discriminator(x_real)  # logits
        y_real = torch.ones(len(x_real), 1, device=device) * .9  # one-sided label smoothing
        d_real_loss = criterion(outputs_real, y_real)
        outputs_fake = discriminator(x_fake.detach())  # detach so gradients don't flow into the generator
        y_fake = torch.zeros(len(x_fake), 1, device=device)
        d_fake_loss = criterion(outputs_fake, y_fake)
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()

        # Train the generator
        g_optimizer.zero_grad()
        noise = torch.rand(len(x_real), z_size, device=device)*2-1  # uniform -1..1
        x_fake = generator(noise)
        y_fake = torch.ones(len(x_real), 1, device=device)  # flip labels
        outputs_fake = discriminator(x_fake)
        g_loss = criterion(outputs_fake, y_fake)
        g_loss.backward()
        g_optimizer.step()

        losses['disc'].append(d_loss.item())
        losses['gen'].append(g_loss.item())

        # Print some loss stats
        if iteration % 100 == 0:
            generator.eval()  # switch batchnorm to eval mode
            with torch.no_grad():
                samples = generator(fixed_noise)
            generator.train()
            time_elapsed = time.time() - time_start
            print(f'Epoch ({time_elapsed:4.1f}s): {e:3} gloss: {g_loss:4.2f} dloss: {d_loss:4.2f}')
            plot_images(samples)
        iteration += 1

    generator.eval()
    with torch.no_grad():
        samples = generator(fixed_noise)
    generator.train()
    time_epoch = time.time() - time_start
    print(f'Epoch ({time_epoch:4.1f}s): {e:3} gloss: {g_loss:4.2f} dloss: {d_loss:4.2f}')
    plot_images(samples)
plot_loss(losses)
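Once training looks reasonable, it may be worth persisting the generator so samples can be drawn later without retraining. A minimal sketch; 'dcgan_generator.pt' is a hypothetical filename:
# Minimal sketch (hypothetical filename): save the generator, reload, and sample
torch.save(generator.state_dict(), 'dcgan_generator.pt')
generator.load_state_dict(torch.load('dcgan_generator.pt'))
generator.eval()
with torch.no_grad():
    z = torch.rand(10, z_size, device=device)*2 - 1
    samples = generator(z)
generator.train()
plot_images(samples)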
The code below is scratch work to help fix a Keras DCGAN :(
raise  # intentionally stop "Run All" here; the cells below are scratch/debug
plt.hist(x_fake.detach().cpu().numpy().ravel(), bins=100);  # distribution of generated pixel values
generator.L0_fc.weight.detach().cpu().shape
p = generator.L0_fc.weight.detach().cpu().numpy().T
plt.hist(p.ravel(), bins=100)
plt.title('L0_fc - weights')
print(p.shape)
p = generator.L0_fc.bias.detach().cpu().numpy().T
plt.hist(p.ravel(), bins=100);
plt.title('L0_fc - bias');
print(p.shape)
p = generator.L1_convT.weight.detach().cpu().numpy().T
plt.hist(p.ravel(), bins=100);
plt.title('L1_convT - weight');
print(p.shape)
generator.L1_bn  # inspect the batchnorm layer's parameters
plt.hist(noise.detach().cpu().numpy().ravel(), bins=100);  # distribution of the input noise