Introduction

This notebook presents a DCGAN architecture trained on the SVHN and CelebA datasets.

Depending on which dataset you want to use, comment out the cells for the other one. There is no need to change the DCGAN code; it works with both datasets.

Imports

In [1]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

SVHN Dataset

In [4]:
# from torchvision import datasets

Download dataset

In [5]:
# dataset = datasets.SVHN(root='~/.pytorch/svhn/', split='train', download=True, transform=None)

Because this is a very small dataset, we load it fully onto the GPU. Otherwise the DataLoader would slow training down considerably.

In [6]:
# x_train = (dataset.data.astype(np.float32) / 127.5) - 1   # scale to range [-1..1]
# x_train = x_train[:-41]  # make divisible by 512
# print(x_train.shape)
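
For contrast, here is a minimal sketch of the DataLoader route we avoid (commented out like the rest of this section): every batch would incur a host-to-device copy during training, rather than a single upfront transfer.

In [ ]:
# from torch.utils.data import TensorDataset, DataLoader
# loader = DataLoader(TensorDataset(torch.from_numpy(x_train)),
#                     batch_size=512, shuffle=True)
# for (x_real,) in loader:
#     x_real = x_real.to(device)   # per-batch transfer -- the overhead we skip
#     ...                          # training step goes here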

Visualize the data

In [7]:
# fig, axes = plt.subplots(nrows=1, ncols=8, figsize=[16,6])
# for i in range(len(axes)):
#     idx = np.random.randint(0, len(x_train))
#     img = x_train[idx]
#     img = img/2 + .5         # -1..1 -> 0..1
#     img = img.transpose(1, 2, 0)
#     axes[i].imshow(img); axes[i].axis('off')

Move to GPU if available

In [8]:
# x_train = torch.tensor(x_train, device=device)

CelebA Dataset

Point this to the folder with CelebA unzipped

In [9]:
dataset_location = '/home/marcin/Datasets/img_align_celeba'
In [10]:
datafile = os.path.join(dataset_location, 'img_align_celeba_32x32.npz')
print(datafile)
/home/marcin/Datasets/img_align_celeba/img_align_celeba_32x32.npz
In [11]:
import PIL.Image

Get the names of all images in the CelebA data folder

In [12]:
all_files = os.listdir(os.path.join(dataset_location, 'img_align_celeba'))

Images in CelebA are pre-aligned, so the eye and mouth locations are nearly identical across images. We exploit this and hardcode a crop to 128x128, followed by a rescale to 32x32.

In [13]:
def crop_celeba(img):
    assert img.size == (178, 218)
    box = [25, 65, img.width-25, img.height-25]  # crop 25/65/25/25 from left/top/right/bottom
    return img.crop(box)                         # result.size = (128, 128)

Show a couple of images with and without cropping

In [14]:
fig, (axes1, axes2) = plt.subplots(nrows=2, ncols=6, figsize=[16,6])
for i in range(len(axes1)):
    idx = i  # np.random.randint(0, len(all_files))
    img_full_path = os.path.join(dataset_location, 'img_align_celeba', all_files[idx])
    img = PIL.Image.open(img_full_path)
    img_128x128 = crop_celeba(img)
    img_64x64 = img_128x128.resize([64, 64], PIL.Image.BICUBIC)
    axes1[i].imshow(img)
    axes1[i].set_title('original')
    axes2[i].imshow(img_64x64)
    axes2[i].set_title('processed')

This function reads, crops, and scales all images, then stacks them into one large numpy array

In [15]:
def load_all_images(dataset_location, all_files, size=32):
    all_images = []
    for i in range(len(all_files)):
        img_full_path = os.path.join(dataset_location, 'img_align_celeba', all_files[i])
        img = PIL.Image.open(img_full_path)
        img_128x128 = crop_celeba(img)
        img_small = img_128x128.resize([size, size], PIL.Image.BICUBIC)  
        arr = np.array(img_small)
        assert arr.shape == (size, size, 3)
        all_images.append(arr)
        if i % 10000 == 0:
            print(f'Image {i} of {len(all_files)}')
    return np.array(all_images)

Check whether the dataset was processed earlier: if not, process it and save to .npz; if yes, just load it from the file

In [16]:
if not os.path.isfile(datafile):
    print('Creating datafile...')
    all_images = load_all_images(dataset_location, all_files, size=32)
    np.savez(datafile, all_images=all_images)
else:
    print('Loading from file...')
    npzfile = np.load(datafile)
    all_images = npzfile['all_images']
Loading from file...

Convert to float32, scale to the -1..1 range, and reorder dimensions to match the PyTorch NCHW convention

In [17]:
x_train = (all_images.astype(np.float32) / 127.5) - 1      # scale to range [-1..1]
x_train = x_train.transpose(0, 3, 1, 2)                    # pytorch ordering NCHW
x_train = x_train[:-103]                                   # make divisible by 128
x_train.shape
Out[17]:
(202496, 3, 32, 32)
In [18]:
del all_images  # save memory

Move everything to GPU if available (approx 2.5GB)

In [22]:
x_train = torch.tensor(x_train, device=device)
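
As a quick sanity check on that estimate (a one-liner sketch), the resident tensor holds 202496 x 3 x 32 x 32 float32 values:

In [ ]:
print(f'{x_train.element_size() * x_train.nelement() / 1e9:.2f} GB')   # ~2.49 GB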

DCGAN

Discriminator

In [23]:
class Discriminator(nn.Module):
    def __init__(self, conv_dim=32):
        super(Discriminator, self).__init__()

        self.L1_conv = nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1, bias=False)   # 32x32 -> 16x16
        self.L1_act = nn.LeakyReLU(0.2)

        self.L2_conv = nn.Conv2d(conv_dim, 2*conv_dim, kernel_size=4, stride=2, padding=1, bias=False)   # 16x16 -> 8x8
        self.L2_bn = nn.BatchNorm2d(2*conv_dim)
        self.L2_act = nn.LeakyReLU(0.2)
        
        self.L3_conv = nn.Conv2d(2*conv_dim, 4*conv_dim, kernel_size=4, stride=2, padding=1, bias=False)   # 8x8 -> 4x4
        self.L3_bn = nn.BatchNorm2d(4*conv_dim)
        self.L3_act = nn.LeakyReLU(0.2)
        
        self.L4_fc = nn.Linear(in_features=4*conv_dim*4*4, out_features=1)

    def forward(self, x):
        x = self.L1_conv(x)
        x = self.L1_act(x)
        
        x = self.L2_conv(x)
        x = self.L2_bn(x)
        x = self.L2_act(x)
        
        x = self.L3_conv(x)
        x = self.L3_bn(x)
        x = self.L3_act(x)
       
        x = x.view(x.size(0), -1)  # flatten
        x = self.L4_fc(x)
        return x                   # return logits
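
A quick smoke test (a throwaway sketch; d_test is not used elsewhere) confirms the downsampling path 32x32 -> 16x16 -> 8x8 -> 4x4 and one logit per image:

In [ ]:
d_test = Discriminator(conv_dim=32).to(device)
print(d_test(torch.randn(8, 3, 32, 32, device=device)).shape)   # torch.Size([8, 1])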

Generator

In [24]:
class Generator(nn.Module):
    def __init__(self, z_size, conv_dim=32):
        super(Generator, self).__init__()
        self.conv_dim = conv_dim
        
        self.L0_fc = nn.Linear(in_features=z_size, out_features=4*conv_dim*4*4)
        
        self.L1_convT = nn.ConvTranspose2d(4*conv_dim, 2*conv_dim, kernel_size=4, stride=2, padding=1, bias=False)   # 4x4 -> 8x8
        self.L1_bn = nn.BatchNorm2d(2*conv_dim)
        self.L1_act = nn.ReLU()
        
        self.L2_convT = nn.ConvTranspose2d(2*conv_dim, conv_dim, kernel_size=4, stride=2, padding=1, bias=False)   # 8x8 -> 16x16
        self.L2_bn = nn.BatchNorm2d(conv_dim)
        self.L2_act = nn.ReLU()
               
        self.L3_convT = nn.ConvTranspose2d(conv_dim, 3, kernel_size=4, stride=2, padding=1, bias=False)   # 16x16 -> 32x32
        self.L3_act = nn.Tanh()
        
    def forward(self, x):
        x = self.L0_fc(x)
        x = x.view(x.size(0), 4*self.conv_dim, 4, 4)
        
        x = self.L1_convT(x)
        x = self.L1_bn(x)
        x = self.L1_act(x)
        
        x = self.L2_convT(x)
        x = self.L2_bn(x)
        x = self.L2_act(x)
        
        x = self.L3_convT(x)
        x = self.L3_act(x)
        return x                # image pixels in range -1..1
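
The mirror-image check for the generator (again a throwaway sketch): the latent vector is projected to 4x4 feature maps and upsampled 4x4 -> 8x8 -> 16x16 -> 32x32:

In [ ]:
g_test = Generator(z_size=100, conv_dim=32).to(device)
print(g_test(torch.randn(8, 100, device=device)).shape)   # torch.Size([8, 3, 32, 32])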

Build models, loss and optimizers

In [25]:
conv_dim = 32
z_size = 100

generator = Generator(z_size=z_size, conv_dim=conv_dim)
discriminator = Discriminator(conv_dim)

generator.to(device)
discriminator.to(device)

criterion = nn.BCEWithLogitsLoss()

lr, betas = 0.0002, (0.5, 0.999)
d_optimizer = optim.Adam(discriminator.parameters(), lr, betas)
g_optimizer = optim.Adam(generator.parameters(), lr, betas)
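
BCEWithLogitsLoss folds the sigmoid into the binary cross-entropy in one numerically stable op, which is why both networks return raw logits. A tiny illustrative check with made-up values:

In [ ]:
logits  = torch.tensor([[2.0], [-1.0]])
targets = torch.tensor([[1.0], [0.0]])
print(nn.BCEWithLogitsLoss()(logits, targets).item())        # ~0.2201
print(nn.BCELoss()(torch.sigmoid(logits), targets).item())   # same value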

Helpers

In [26]:
def plot_images(x_fake):  # expects tensor of shape [batch, 3, height, width]
    fig, axes = plt.subplots(nrows=1, ncols=10, figsize=[16,9])
    for i, ax in enumerate(axes):
        img = x_fake[i].detach().cpu().permute(1, 2, 0).numpy()
        ax.imshow(img/2+.5)
        ax.axis('off')
    plt.show()
    
def plot_loss(losses):
    fig, ax = plt.subplots(nrows=1, ncols=1)
    ax.plot(losses['disc'], label='disc')
    ax.plot(losses['gen'], label='gen')
    ax.legend()
    plt.show()
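
A quick way to sanity-check plot_images (and the preprocessing) is to run it on a few real images:

In [ ]:
plot_images(x_train[:10])   # real CelebA samples, already scaled to -1..1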

A few more hyperparameters. fixed_noise is used to generate a consistent set of sample images every epoch.

In [28]:
n_batch = 128
n_epochs = 1     # even 1 epoch shows some face-ish results; 10-100 epochs are recommended

fixed_noise = torch.rand(16, z_size, device=device)*2-1   # uniform -1..1

Train the model

In [29]:
losses = {'gen':[], 'disc':[]}
indices = np.array(range(len(x_train)))

iteration = 0
for e in range(n_epochs):
    
    time_start = time.time()
    
    np.random.shuffle(indices)
    for i in range(0, len(x_train), n_batch):
        
        # Pick next batch of real images
        i_batch = indices[i:i+n_batch]
        x_real = x_train[i_batch]
        
        # Generate fake images
        noise = torch.rand(len(x_real), z_size, device=device)*2-1  # uniform -1..1
        x_fake = generator(noise)
        
        
        # Train the discriminator
        d_optimizer.zero_grad()
        
        outputs_real = discriminator(x_real)  # logits
        y_real = torch.ones(len(x_real), 1, device=device) * .9   # one-sided label smoothing
        d_real_loss = criterion(outputs_real, y_real)
        
        outputs_fake = discriminator(x_fake.detach())  # logits; detach so D's backward pass does not reach G
        y_fake = torch.zeros(len(x_fake), 1, device=device)
        d_fake_loss = criterion(outputs_fake, y_fake)
        
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        d_optimizer.step()
        
        
        # Train the generator
        g_optimizer.zero_grad()
        
        noise = torch.rand(len(x_real), z_size, device=device)*2-1  # uniform -1..1
        x_fake = generator(noise)
        y_fake = torch.ones(len(x_real), 1, device=device)   # flipped labels: call fakes real for G's update
        outputs_fake = discriminator(x_fake)
        
        g_loss = criterion(outputs_fake, y_fake)
        g_loss.backward()
        g_optimizer.step()
        
        losses['disc'].append(d_loss.item())
        losses['gen'].append(g_loss.item())
        
        # Print some loss stats
        if iteration % 100 == 0:
            
            generator.eval()                   # switch batchnorm to eval mode
            samples = generator(fixed_noise)
            generator.train()
    
            time_epoch = time.time() - time_start
            print(f'Epoch ({time_epoch:4.1f}s): {e:3}     gloss: {g_loss:4.2f}     dloss: {d_loss:4.2f}')
            plot_images(samples)
            
        iteration += 1

    
    generator.eval()
    samples = generator(fixed_noise)
    generator.train()
    
    time_epoch = time.time() - time_start
    print(f'Epoch ({time_epoch:4.1f}s): {e:3}     gloss: {g_loss:4.2f}     dloss: {d_loss:4.2f}')
    plot_images(samples)
    plot_loss(losses)
Epoch ( 0.0s):   0     gloss: 0.95     dloss: 1.34
Epoch ( 1.9s):   0     gloss: 3.77     dloss: 0.40
Epoch ( 3.7s):   0     gloss: 2.11     dloss: 0.60
Epoch ( 5.6s):   0     gloss: 2.21     dloss: 0.71
Epoch ( 7.5s):   0     gloss: 2.36     dloss: 0.77
Epoch ( 9.3s):   0     gloss: 1.50     dloss: 0.85
Epoch (11.1s):   0     gloss: 2.08     dloss: 0.82
Epoch (12.9s):   0     gloss: 1.59     dloss: 0.91
Epoch (14.8s):   0     gloss: 1.99     dloss: 0.97
Epoch (16.5s):   0     gloss: 1.54     dloss: 0.91
Epoch (18.3s):   0     gloss: 1.52     dloss: 0.83
Epoch (20.2s):   0     gloss: 1.57     dloss: 1.00
Epoch (22.0s):   0     gloss: 1.19     dloss: 0.96
Epoch (23.9s):   0     gloss: 1.32     dloss: 0.86
Epoch (25.8s):   0     gloss: 1.32     dloss: 0.87
Epoch (27.6s):   0     gloss: 1.25     dloss: 1.00
Epoch (29.1s):   0     gloss: 1.48     dloss: 0.96
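
If you want to keep the result, a minimal checkpointing sketch (the filename dcgan_generator.pt is arbitrary) could be:

In [ ]:
torch.save(generator.state_dict(), 'dcgan_generator.pt')
# to restore later:
# generator.load_state_dict(torch.load('dcgan_generator.pt', map_location=device))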

Debug code below this point

The code below was written to help debug a Keras DCGAN :(

In [ ]:
raise  # guard: stop "Run All" before reaching the debug cells
In [29]:
plt.hist(x_fake.detach().cpu().numpy().ravel(), bins=100);
In [35]:
generator.L0_fc.weight.detach().cpu().shape
Out[35]:
torch.Size([2048, 100])
In [46]:
p = generator.L0_fc.weight.detach().cpu().numpy().T
plt.hist(p.ravel(), bins=100)
plt.title('L0_fc - weights')
print(p.shape)
(100, 2048)
In [47]:
p = generator.L0_fc.bias.detach().cpu().numpy().T
plt.hist(p.ravel(), bins=100);
plt.title('L0_fc - bias');
print(p.shape)
(2048,)
In [48]:
p = generator.L1_convT.weight.detach().cpu().numpy().T
plt.hist(p.ravel(), bins=100);
plt.title('L1_convT - weight');
print(p.shape)
(4, 4, 64, 128)
In [49]:
generator.L1_bn
Out[49]:
BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
In [51]:
plt.hist(noise.detach().cpu().numpy().ravel(), bins=100);