Introduction

This notebook presents an LSTM network trained to perform sentiment analysis on the IMDB movie reviews dataset.


Imports

In [1]:
import os
import re
import time
import collections
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

Pick GPU if available

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

IMDB Dataset

Download the dataset from here and extract it. Point the path below to the extracted location.

In [3]:
dataset_location = '/home/marcin/Datasets/imdb'

Helper to load the dataset

In [4]:
def load_imdb_dataset(dataset_loc):
    def read_reviews(path, label, reviews, labels):
        # one review per file; label is 1 for positive, 0 for negative
        for filename in sorted(os.listdir(path)):
            with open(os.path.join(path, filename), encoding='utf-8') as f:
                reviews.append(f.read())
                labels.append(label)
        return reviews, labels
    
    path_train_pos = os.path.join(dataset_loc, 'aclImdb_v1/aclImdb/train/pos')
    path_train_neg = os.path.join(dataset_loc, 'aclImdb_v1/aclImdb/train/neg')
    path_test_pos = os.path.join(dataset_loc, 'aclImdb_v1/aclImdb/test/pos')
    path_test_neg = os.path.join(dataset_loc, 'aclImdb_v1/aclImdb/test/neg')
    
    train_revs, train_labels = [], []
    train_revs, train_labels = read_reviews(path_train_pos, 1, train_revs, train_labels)
    train_revs, train_labels = read_reviews(path_train_neg, 0, train_revs, train_labels)
    
    test_revs, test_labels = [], []
    test_revs, test_labels = read_reviews(path_test_pos, 1, test_revs, test_labels)
    test_revs, test_labels = read_reviews(path_test_neg, 0, test_revs, test_labels)
    
    return (train_revs, train_labels), (test_revs, test_labels)

Load dataset

In [5]:
train_data, test_data = load_imdb_dataset(dataset_location)
train_reviews_raw, train_labels_raw = train_data
test_reviews_raw, test_labels_raw = test_data   

Look at the Data

Let's look at a sample review

In [6]:
print(train_reviews_raw[0])
Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as "Teachers". My 35 years in the teaching profession lead me to believe that Bromwell High's satire is much closer to reality than is "Teachers". The scramble to survive financially, the insightful students who can see right through their pathetic teachers' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line: INSPECTOR: I'm here to sack one of your teachers. STUDENT: Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn't!

Count words in the dataset

In [7]:
def count_words(list_of_examples):
    # list of strings: split each example into words on whitespace
    # list of lists: examples are already split into words
    split = isinstance(list_of_examples[0], str)
    
    words_counter = collections.Counter()
    for example in list_of_examples:
        if split:
            words_counter.update(example.split())
        else:
            words_counter.update(example)
            
    total_words = sum(list(words_counter.values()))
    unique_words = len(words_counter)
    
    return total_words, unique_words, words_counter
In [8]:
total_words, unique_words, words_counter = count_words(train_reviews_raw)
print('Total words: ', total_words)
print('Unique words: ', unique_words)
Total words:  5844680
Unique words:  280617

There are 5.8M words (separated by whitespace) and 280k unique words.

Next, let's look at the word count distribution.

In [9]:
def plot_counts(words_counter, title):
    sorted_all = np.array(sorted(list(words_counter.values()), reverse=True))
    fig, [ax1, ax2] = plt.subplots(1, 2, figsize=[16,6])
    ax1.plot(sorted_all); ax1.set_title(title + ' Counts (linear scale)')
    ax2.plot(sorted_all); ax2.set_title(title + ' Counts (log scale)')
    ax2.set_yscale('log')
In [10]:
plot_counts(words_counter, title='Word')

Some words appear 300k times (left plot), while over 150k "words" appear only once (right plot).
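To put a number on the long tail in the right plot, one can count the words whose count is exactly one. A minimal sketch on a toy corpus; `frequency_summary` is a hypothetical helper, and on the real data you would pass it the `words_counter` returned by `count_words`.

```python
import collections

def frequency_summary(counter):
    """Return (number of distinct words, number of words seen exactly once)."""
    once = sum(1 for c in counter.values() if c == 1)
    return len(counter), once

# toy corpus, split on whitespace as count_words() does for raw strings
docs = ["the movie was good", "the film was bad bad"]
ctr = collections.Counter()
for d in docs:
    ctr.update(d.split())

unique, once = frequency_summary(ctr)
print(unique, once)
```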

Preprocess Data

We are going to perform the following pre-processing steps:

  • text cleanup - convert to lowercase and remove any non a-z characters
  • remove stopwords - remove words like 'the', 'a', 'an' and so on
  • reduce vocabulary - keep the 10,000 most common words, including the <PAD> and <UNK> tokens (a size commonly used in tf.keras.datasets.imdb examples)
  • tokenize - convert words to integers

Text Cleanup

Text cleanup consists of the following steps:

  • convert to lowercase
  • keep only a-z characters, convert everything else to a space
    • (we could spell out digits, e.g. convert 20 to 'two zero', but we won't bother here)
  • split on whitespace, which also discards runs of consecutive spaces

This leaves us with a dataset built from just 26 letters.

Note that words like "didn't" will be converted to "didn t", but that's OK: "did" and "didn" are still encoded as different words, and the leftover "t" can be dropped when removing stopwords.

NOTE: Keeping the apostrophe in the regex sometimes helps.
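To see what the apostrophe note means in practice, here is a small sketch comparing the notebook's character class `[^a-z ]+` with a hypothetical variant `[^a-z' ]+` that also keeps apostrophes. The sample sentence is made up; it also shows how an HTML `<br />` tag turns into the token "br", which is why "br" is added to the stopwords later.

```python
import re

STRICT = re.compile(r"[^a-z ]+")      # as in text_cleanup: keep only a-z and space
KEEP_APOS = re.compile(r"[^a-z' ]+")  # hypothetical variant that also keeps apostrophes

def cleanup(text, regex):
    """Lowercase, replace runs of disallowed characters with a space, split on whitespace."""
    return regex.sub(" ", text.lower()).split()

review = "I didn't like it.<br />Really!"
print(cleanup(review, STRICT))
print(cleanup(review, KEEP_APOS))
```

With the strict pattern "didn't" becomes two tokens; with the variant it stays one token, at the cost of a larger vocabulary (every contraction becomes its own word).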
In [12]:
def text_cleanup(list_of_texts):
    """Perform text cleanup, reduce to a-z and space."""
    regex = re.compile('[^a-z ]+')    # matches runs of anything that is not a-z or space
    
    def cleanup(text):
        res = text.lower()
        res = regex.sub(' ', res)     # replace those runs with a single space
        return res.split()            # split on whitespace
    
    result_cleaned = []
    for text in list_of_texts:
        result_cleaned.append(cleanup(text))
    return result_cleaned             # list of lists of words
In [13]:
train_reviews = text_cleanup(train_reviews_raw)
test_reviews = text_cleanup(test_reviews_raw)
print(train_reviews[0])
['bromwell', 'high', 'is', 'a', 'cartoon', 'comedy', 'it', 'ran', 'at', 'the', 'same', 'time', 'as', 'some', 'other', 'programs', 'about', 'school', 'life', 'such', 'as', 'teachers', 'my', 'years', 'in', 'the', 'teaching', 'profession', 'lead', 'me', 'to', 'believe', 'that', 'bromwell', 'high', 's', 'satire', 'is', 'much', 'closer', 'to', 'reality', 'than', 'is', 'teachers', 'the', 'scramble', 'to', 'survive', 'financially', 'the', 'insightful', 'students', 'who', 'can', 'see', 'right', 'through', 'their', 'pathetic', 'teachers', 'pomp', 'the', 'pettiness', 'of', 'the', 'whole', 'situation', 'all', 'remind', 'me', 'of', 'the', 'schools', 'i', 'knew', 'and', 'their', 'students', 'when', 'i', 'saw', 'the', 'episode', 'in', 'which', 'a', 'student', 'repeatedly', 'tried', 'to', 'burn', 'down', 'the', 'school', 'i', 'immediately', 'recalled', 'at', 'high', 'a', 'classic', 'line', 'inspector', 'i', 'm', 'here', 'to', 'sack', 'one', 'of', 'your', 'teachers', 'student', 'welcome', 'to', 'bromwell', 'high', 'i', 'expect', 'that', 'many', 'adults', 'of', 'my', 'age', 'think', 'that', 'bromwell', 'high', 'is', 'far', 'fetched', 'what', 'a', 'pity', 'that', 'it', 'isn', 't']

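Let's check the word counts again. Lowercasing and removing punctuation merge variants such as "Movie" and "movie," into one word, so the number of unique words drops sharply; the total changes slightly, partly because contractions such as "didn't" now count as two words.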

In [14]:
total_words, unique_words, words_counter = count_words(train_reviews)
print('Total words: ', total_words)
print('Unique words: ', unique_words)
Total words:  6023662
Unique words:  73272

Remove Stopwords

NOTE: This step doesn't provide much improvement on this task

Check the most common words: they contribute little to the overall meaning of a sentence.

In [17]:
display(words_counter.most_common()[:10])
[('the', 336758),
 ('and', 164143),
 ('a', 163174),
 ('of', 145867),
 ('to', 135724),
 ('is', 107337),
 ('br', 101872),
 ('it', 96472),
 ('in', 93981),
 ('i', 87702)]

List of stopwords from NLTK

In [18]:
# import nltk
# nltk.download('stopwords')
# en_stopwords = nltk.corpus.stopwords.words('english')
# stopwords = {sw for sw in en_stopwords}
# print(stopwords)
stopwords = {'down', 'then', 'of', 'but', 'only', 'yours', 'himself', 'again',
             'very', 'or', 'once', 'until', 'have', "doesn't", 'what', 'during',
             "that'll", 'some', 'was', 'be', 'he', "should've", 'between',
             "shouldn't", 'further', 'no', 'yourself', 'm', 've', "you'll",
             'ain', 't', 'our', 'his', 'o', 'wouldn', 'below', 'any', 'under',
             'you', 'isn', 'theirs', 'why', 'that', 'mightn', 'ourselves', 'on',
             'haven', 'while', 'to', 'than', 'your', 'she', 'is', 'just',
             "mightn't", 'with', "you've", 'mustn', 'needn', 'same', 'me',
             'such', 'myself', 'there', 'own', 'this', 're', 'ma', 'from',
             'did', 'couldn', 'hasn', 'for', 'won', "won't", "mustn't", 'her',
             'can', 'doesn', "wouldn't", 'when', "you're", 'who', 'which', 'll',
             'itself', 'against', 'out', 'up', "it's", 'a', 'here', 'being',
             'they', 'as', 'didn', 'weren', 'aren', 'herself', 'the', 'if',
             "didn't", 'should', 'doing', 'other', 'has', 'so', "you'd",
             'above', 'do', 'before', 'at', 'had', 'each', "aren't", 'their',
             'now', 'an', 'through', 'how', 'those', 'nor', "hasn't", 'over',
             'by', 'into', 'themselves', 'most', 'shan', 'been', "she's",
             "haven't", "isn't", "wasn't", 'where', 'about', 'in', "hadn't",
             'because', 'too', 'whom', 'ours', 'him', 'yourselves', 'after',
             'and', 'were', 'both', 'will', 'it', 'my', 'few', 'having', 'them',
             'hadn', 'shouldn', 'does', 's', "couldn't", 'y', 'all', 'don',
             'off', 'more', 'am', 'd', 'hers', 'its', 'are', "shan't",
             "weren't", 'we', "needn't", 'i', 'these', "don't", 'wasn', 'not'}
In [19]:
stopwords.add('br')  # <br /> tag in a lot of reviews

Remove stopwords

In [20]:
def remove_stopwords(list_of_examples, stopwords):
    result_no_stop = []
    for list_of_words in list_of_examples:
        result_no_stop.append( [w for w in list_of_words if w not in stopwords])
    return result_no_stop
In [21]:
train_reviews_no_stop = remove_stopwords(train_reviews, stopwords)
test_reviews_no_stop = remove_stopwords(test_reviews, stopwords)

Show sample review

In [22]:
print(train_reviews_no_stop[0])
['bromwell', 'high', 'cartoon', 'comedy', 'ran', 'time', 'programs', 'school', 'life', 'teachers', 'years', 'teaching', 'profession', 'lead', 'believe', 'bromwell', 'high', 'satire', 'much', 'closer', 'reality', 'teachers', 'scramble', 'survive', 'financially', 'insightful', 'students', 'see', 'right', 'pathetic', 'teachers', 'pomp', 'pettiness', 'whole', 'situation', 'remind', 'schools', 'knew', 'students', 'saw', 'episode', 'student', 'repeatedly', 'tried', 'burn', 'school', 'immediately', 'recalled', 'high', 'classic', 'line', 'inspector', 'sack', 'one', 'teachers', 'student', 'welcome', 'bromwell', 'high', 'expect', 'many', 'adults', 'age', 'think', 'bromwell', 'high', 'far', 'fetched', 'pity']

And word counts

In [23]:
total_words, unique_words, words_counter = count_words(train_reviews_no_stop)
print('Total words: ', total_words)
print('Unique words: ', unique_words)
Total words:  2988387
Unique words:  73118

Reduce Vocabulary

Likewise, check the rarest words. They don't carry much meaning either (what is a "lagomorph" anyway?)

In [24]:
display(words_counter.most_common()[-10:])
[('lagomorph', 1),
 ('ziller', 1),
 ('deamon', 1),
 ('yaks', 1),
 ('hoodies', 1),
 ('insulation', 1),
 ('mwuhahahaa', 1),
 ('bellwood', 1),
 ('pressurized', 1),
 ('whelk', 1)]

We will reduce the vocabulary to the 9,998 most common words plus the <PAD> and <UNK> tokens, for a total of 10,000 words.
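The two steps below (pick the most frequent words, then replace everything else with `<UNK>`) can be sketched on a toy corpus. `build_reduced_vocab` is a hypothetical helper combining `get_most_common_words` and `reduce_vocabulary`; it uses `Counter.most_common(num_words)`, which returns the same words as the notebook's `most_common()[:num_words]`.

```python
import collections

def build_reduced_vocab(examples, num_words, unk_tok="<UNK>"):
    """Keep the num_words most frequent words; map every other word to unk_tok."""
    ctr = collections.Counter()
    for ex in examples:
        ctr.update(ex)
    keep = {w for w, _ in ctr.most_common(num_words)}
    reduced = [[w if w in keep else unk_tok for w in ex] for ex in examples]
    return keep, reduced

examples = [["great", "movie", "great", "cast"], ["boring", "movie"]]
keep, reduced = build_reduced_vocab(examples, num_words=2)
print(sorted(keep))
print(reduced)
```

In the notebook `num_words=9998`; `<UNK>` and `<PAD>` then bring the dictionary to 10,000 entries.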

In [25]:
def get_most_common_words(list_of_examples, num_words):
    words_ctr = collections.Counter()
    for example in list_of_examples:
        words_ctr.update(example)
    
    keep_words = {w for w, n in words_ctr.most_common()[:num_words]}
    return keep_words
In [26]:
allowed_words = get_most_common_words(train_reviews_no_stop, 9998)

Print some of the allowed words

In [27]:
print([w for w in allowed_words][:20])
['chuckles', 'half', 'productions', 'kidding', 'ewoks', 'strip', 'wisdom', 'farnsworth', 'official', 'pounds', 'presidential', 'con', 'treat', 'cringing', 'spencer', 'bio', 'loretta', 'infant', 'prove', 'gag']

And reduce vocabulary

In [28]:
def reduce_vocabulary(list_of_examples, allowed_words, unk_tok='<UNK>'):
    result_reduced = []
    for example in list_of_examples:
        result_reduced.append( [w if w in allowed_words else unk_tok for w in example] )
    return result_reduced
In [29]:
train_reviews_reduced = reduce_vocabulary(train_reviews_no_stop, allowed_words)
test_reviews_reduced = reduce_vocabulary(test_reviews_no_stop, allowed_words)

Show example after reduction

In [30]:
print(train_reviews_reduced[0])
['<UNK>', 'high', 'cartoon', 'comedy', 'ran', 'time', 'programs', 'school', 'life', 'teachers', 'years', 'teaching', 'profession', 'lead', 'believe', '<UNK>', 'high', 'satire', 'much', 'closer', 'reality', 'teachers', '<UNK>', 'survive', '<UNK>', 'insightful', 'students', 'see', 'right', 'pathetic', 'teachers', '<UNK>', '<UNK>', 'whole', 'situation', 'remind', 'schools', 'knew', 'students', 'saw', 'episode', 'student', 'repeatedly', 'tried', 'burn', 'school', 'immediately', '<UNK>', 'high', 'classic', 'line', 'inspector', 'sack', 'one', 'teachers', 'student', 'welcome', '<UNK>', 'high', 'expect', 'many', 'adults', 'age', 'think', '<UNK>', 'high', 'far', 'fetched', 'pity']

And count words

In [31]:
total_words, unique_words, words_counter = count_words(train_reviews_reduced)
print('Total words: ', total_words)
print('Unique words: ', unique_words)
Total words:  2988387
Unique words:  9999

Create dictionaries

Index 0 is reserved for "<PAD>": trim_and_pad below fills short reviews with zeros, although the model itself does not mask the padded positions.

In [32]:
i2w = {i : w for i, (w, c) in enumerate(words_counter.most_common(), 1)}
w2i = {w : i for i, w in i2w.items()}
i2w[0] = '<PAD>'                       # use zero index for padding
w2i[i2w[0]] = 0
print('Number of words in dictionaries:', len(i2w))
Number of words in dictionaries: 10000
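As a sanity check on the dictionary layout (words indexed from 1 in descending frequency, index 0 reserved for `<PAD>`), here is a toy round trip: encoding a review with `w2i` and decoding it with `i2w` should give back the original words. `build_dicts` is a hypothetical helper and the data is made up.

```python
import collections

def build_dicts(examples, pad_tok="<PAD>"):
    """Index words by descending frequency starting at 1; reserve index 0 for padding."""
    ctr = collections.Counter(w for ex in examples for w in ex)
    i2w = {i: w for i, (w, _) in enumerate(ctr.most_common(), start=1)}
    i2w[0] = pad_tok
    w2i = {w: i for i, w in i2w.items()}
    return i2w, w2i

examples = [["<UNK>", "movie", "good"], ["movie", "bad"]]
i2w, w2i = build_dicts(examples)
encoded = [[w2i[w] for w in ex] for ex in examples]
decoded = [[i2w[i] for i in ex] for ex in encoded]
print(encoded)
```

Because real words start at 1, a 0 in an encoded review can only mean padding.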

And confirm the dictionaries are built correctly

In [33]:
for i in range(10):
    word = i2w[i]
    print(i, ':', word, ':', w2i[word])
0 : <PAD> : 0
1 : <UNK> : 1
2 : movie : 2
3 : film : 3
4 : one : 4
5 : like : 5
6 : good : 6
7 : time : 7
8 : even : 8
9 : would : 9

Print subset of vocabulary

In [34]:
print(sorted(list(i2w.values()))[:100])
['<PAD>', '<UNK>', 'aaron', 'abandon', 'abandoned', 'abbott', 'abc', 'abducted', 'abilities', 'ability', 'able', 'aboard', 'abominable', 'abomination', 'abortion', 'abound', 'abraham', 'abroad', 'abrupt', 'abruptly', 'absence', 'absent', 'absolute', 'absolutely', 'absorbed', 'absorbing', 'abstract', 'absurd', 'absurdity', 'abu', 'abundance', 'abuse', 'abused', 'abusive', 'abysmal', 'academic', 'academy', 'accent', 'accents', 'accept', 'acceptable', 'acceptance', 'accepted', 'accepting', 'accepts', 'access', 'accessible', 'accident', 'accidental', 'accidentally', 'acclaim', 'acclaimed', 'accompanied', 'accompany', 'accompanying', 'accomplish', 'accomplished', 'accomplishment', 'according', 'account', 'accounts', 'accuracy', 'accurate', 'accurately', 'accused', 'ace', 'achieve', 'achieved', 'achievement', 'achievements', 'achieves', 'achieving', 'acid', 'acknowledge', 'acknowledged', 'acquire', 'acquired', 'across', 'act', 'acted', 'acting', 'action', 'actions', 'active', 'activities', 'activity', 'actor', 'actors', 'actress', 'actresses', 'acts', 'actual', 'actuality', 'actually', 'ad', 'adam', 'adams', 'adapt', 'adaptation', 'adaptations']

Tokenize

Convert words into integer tokens

In [35]:
def tokenize(list_of_examples, word2idx):
    result_tokenized = []
    for list_of_words in list_of_examples:
        result_tokenized.append( [word2idx[w] for w in list_of_words] )
    return result_tokenized
In [36]:
train_reviews_tok = tokenize(train_reviews_reduced, w2i)
test_reviews_tok = tokenize(test_reviews_reduced, w2i)

Show example

In [37]:
print(train_reviews_tok[0])
[1, 194, 916, 102, 1994, 7, 5643, 264, 34, 5050, 59, 4829, 5707, 354, 148, 1, 194, 1839, 14, 2254, 486, 5050, 1, 1844, 1, 5644, 1363, 12, 99, 1067, 5050, 1, 1, 115, 752, 2848, 5645, 559, 1363, 110, 265, 1275, 3544, 653, 3344, 264, 1080, 1, 194, 233, 225, 2771, 7543, 4, 5050, 1275, 2185, 1, 194, 404, 33, 1314, 418, 26, 1, 194, 120, 4001, 2072]

Trimming and Padding

Next, let's look at review lengths.

In [38]:
lengths = np.array([len(r) for r in train_reviews_tok])
lengths_counter = collections.Counter(lengths)
In [39]:
print('Shortest review:', lengths.min())
print('Longest review: ', lengths.max())
Shortest review: 4
Longest review:  1422
In [40]:
plot_counts(lengths_counter, title='Length')

Let's see how many reviews are over 200 words long.

In [41]:
num_over_200 = (lengths > 200).sum()
percentage_over_200 = 100*num_over_200 / len(lengths)
print('Num over 200:        ', num_over_200)
print('Percentage over 200: ', round(percentage_over_200, 2))
Num over 200:         3541
Percentage over 200:  14.16

To feed reviews into the neural network in mini-batches, all reviews must have the same length.

We will:

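  • trim all reviews to their first 200 words
  • left-pad all shorter reviews with <PAD> (index 0), so the real words always end at the last time step, which is the only LSTM output the model uses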
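A toy run of the same trim-and-left-pad logic, rewritten here as a standalone function so the block runs on its own (the token values are made up). The last column always holds the final real token of each review, which is what `x[:, -1, :]` in the model reads; with right-padding, short reviews would instead end in a run of `<PAD>` steps.

```python
import numpy as np

def trim_and_pad(examples, target_len, dtype=np.int16):
    """Keep the first target_len tokens of each example and left-pad with 0 (<PAD>)."""
    out = np.zeros((len(examples), target_len), dtype=dtype)
    for i, ex in enumerate(examples):
        ex = ex[:target_len]                  # trim: keep the first target_len tokens
        out[i, target_len - len(ex):] = ex    # real tokens end at the last column
    return out

batch = trim_and_pad([[5, 6, 7], [1, 2, 3, 4, 5, 6]], target_len=4)
print(batch)
print(batch[:, -1])    # the last time step of every row is a real token
```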
In [42]:
def trim_and_pad(list_of_examples, target_len, dtype):
    result_np = np.zeros(shape=(len(list_of_examples), target_len), dtype=dtype)
    
    for i in range(len(list_of_examples)):
        example_trimmed = list_of_examples[i][:target_len]      # trim
        start = target_len - len(example_trimmed)
        result_np[i, start:target_len] = example_trimmed
        
    return result_np

With 10k words in the vocabulary, token indices fit in the int16 type (maximum 32,767).

In [43]:
target_type = np.int16
assert len(i2w) < np.iinfo(target_type).max
train_reviews_np = trim_and_pad(train_reviews_tok, target_len=200, dtype=target_type)
test_reviews_np = trim_and_pad(test_reviews_tok, target_len=200, dtype=target_type)
In [44]:
print(train_reviews_np)
[[   0    0    0 ...  120 4001 2072]
 [   1    1  578 ...    1  161  166]
 [   0    0    0 ...   12   78  276]
 ...
 [   0    0    0 ...   15    2   16]
 [   0    0    0 ...    1 1043  650]
 [   0    0    0 ...  311    7 1194]]
In [45]:
print(train_reviews_np[0])
[   0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    0    0    0    0    0    0    0    0    0
    0    0    0    0    0    1  194  916  102 1994    7 5643  264   34
 5050   59 4829 5707  354  148    1  194 1839   14 2254  486 5050    1
 1844    1 5644 1363   12   99 1067 5050    1    1  115  752 2848 5645
  559 1363  110  265 1275 3544  653 3344  264 1080    1  194  233  225
 2771 7543    4 5050 1275 2185    1  194  404   33 1314  418   26    1
  194  120 4001 2072]
In [46]:
' '.join(i2w[c] for c in train_reviews_np[0] if c != 0)
Out[46]:
'<UNK> high cartoon comedy ran time programs school life teachers years teaching profession lead believe <UNK> high satire much closer reality teachers <UNK> survive <UNK> insightful students see right pathetic teachers <UNK> <UNK> whole situation remind schools knew students saw episode student repeatedly tried burn school immediately <UNK> high classic line inspector sack one teachers student welcome <UNK> high expect many adults age think <UNK> high far fetched pity'

Convert Labels

In [47]:
train_labels = np.array(train_labels_raw).reshape(-1, 1)
test_labels = np.array(test_labels_raw).reshape(-1, 1)
In [48]:
print(train_labels)
[[1]
 [1]
 [1]
 ...
 [0]
 [0]
 [0]]

Convert to Tensors

In [49]:
train_features = torch.tensor(train_reviews_np, dtype=torch.int64, device=device)
train_targets = torch.tensor(train_labels, dtype=torch.float32, device=device)
test_features = torch.tensor(test_reviews_np, dtype=torch.int64, device=device)
test_targets = torch.tensor(test_labels, dtype=torch.float32, device=device)
In [50]:
print(train_features.shape, train_features.dtype)
print(train_targets.shape, train_targets.dtype)
print(test_features.shape, test_features.dtype)
print(test_targets.shape, test_targets.dtype)
torch.Size([25000, 200]) torch.int64
torch.Size([25000, 1]) torch.float32
torch.Size([25000, 200]) torch.int64
torch.Size([25000, 1]) torch.float32

PyTorch Model

Helper function for accuracy

In [51]:
def accuracy(pred, tar): 
    return (pred == tar).float().mean()  # 0-dim tensor; call .item() to get a Python float
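The model returns raw logits: nn.BCEWithLogitsLoss applies the sigmoid inside the loss (with the default mean reduction), and the notebook calls torch.sigmoid and round only to turn logits into 0/1 predictions for accuracy. A NumPy sketch of both steps; the function names are hypothetical, and the loss uses the standard numerically stable rewrite of binary cross-entropy, not PyTorch's source code.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed directly from logits (numerically stable form)."""
    return np.mean(np.maximum(logits, 0) - logits * targets
                   + np.log1p(np.exp(-np.abs(logits))))

def accuracy_from_logits(logits, targets):
    """Round sigmoid probabilities to 0/1, then take the fraction that match the targets."""
    predictions = np.round(sigmoid(logits))
    return np.mean(predictions == targets)

logits = np.array([[2.0], [-1.5], [0.3], [-0.2]])   # shape [n_batch, 1], like the model output
targets = np.array([[1.0], [0.0], [0.0], [0.0]])    # shape [n_batch, 1], like train_targets

p = sigmoid(logits)
naive = -np.mean(targets * np.log(p) + (1 - targets) * np.log(1 - p))
print(bce_with_logits(logits, targets), naive)      # the two formulas agree
print(accuracy_from_logits(logits, targets))
```

The stable form avoids computing log(sigmoid(x)) directly, which underflows for large negative logits.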

Evaluate Helper

In [53]:
def evaluate(data_features, data_targets, batch_size):
    """Return (mean loss, mean accuracy, predictions); uses the global model and criterion."""
    
    predictions_all = torch.zeros_like(data_targets)   # model outputs
    
    model.eval()
    
    loss_sum, acc_sum = 0, 0
    with torch.no_grad():        
        for i in range(0, len(data_features), batch_size):

            # Pick mini-batch
            inputs = data_features[i:i+batch_size]
            targets = data_targets[i:i+batch_size]

            # Forward pass
            logits = model(inputs)
            loss = criterion(logits, targets)
            probabilities = torch.sigmoid(logits)
            predictions = probabilities.round()
            acc = accuracy(predictions, targets)

            predictions_all[i:i+batch_size] = predictions
            
            # Accumulate sums weighted by batch size (the last batch may be smaller)
            loss_sum += loss.item() * len(inputs)
            acc_sum += acc.item() * len(inputs)
        
    loss_avg = loss_sum / len(data_features)
    acc_avg = acc_sum / len(data_features)
    
    return loss_avg, acc_avg, predictions_all
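evaluate() multiplies each mini-batch mean by len(inputs) before dividing by the dataset size. This gives the exact per-example mean even when the last batch is smaller than batch_size, which happens with a batch size that does not divide 25,000 (250 does). A toy NumPy check with a hypothetical helper:

```python
import numpy as np

def weighted_mean_over_batches(per_example, batch_size):
    """Average per-batch means, weighting each by its batch size (as evaluate() does)."""
    total = 0.0
    for i in range(0, len(per_example), batch_size):
        batch = per_example[i:i + batch_size]
        total += batch.mean() * len(batch)    # undo the per-batch mean
    return total / len(per_example)

values = np.arange(10, dtype=float)           # 10 examples -> batches of 4, 4 and 2
weighted = weighted_mean_over_batches(values, batch_size=4)
unweighted = np.mean([values[i:i + 4].mean() for i in range(0, len(values), 4)])
print(weighted, values.mean(), unweighted)
```

The plain average of batch means over-weights the short final batch.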

Helper for plotting

In [54]:
def plot_trace(trace):
    fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 4))
    ax1.plot(trace['tloss'], label='train loss')
    ax1.plot(trace['vloss'], label='valid loss')
    ax1.set_xlabel('Epoch')
    ax1.legend()
    
    ax2.plot(trace['tacc'], label='train acc')
    ax2.plot(trace['vacc'], label='valid acc')
    ax2.set_xlabel('Epoch')
    ax2.legend()
    plt.show()

Simple LSTM model

In [55]:
class SentimentNetwork(nn.Module):
    def __init__(self, nb_layers, n_vocab, n_embed, n_hid, n_out, dropout):
        super(SentimentNetwork, self).__init__()
        
        self.embed = nn.Embedding(num_embeddings=n_vocab, embedding_dim=n_embed)
        self.lstm = nn.LSTM(input_size=n_embed, hidden_size=n_hid, num_layers=nb_layers,
                            batch_first=True, dropout=dropout)  # dropout between stacked LSTM layers
        self.drop = nn.Dropout(p=dropout)   # note: defined but not applied in forward()
        self.fc = nn.Linear(in_features=n_hid, out_features=n_out)
        
    def forward(self, x):
        x = self.embed(x)                   # shape [n_batch, n_seq, n_embed]
        x, _ = self.lstm(x)                 # shape [n_batch, n_seq, n_hid]
        x = x[:, -1, :]                     # shape [n_batch, n_hid]; keep last time step only
        return self.fc(x)                   # shape [n_batch, n_out]
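To make the shape comments in forward() concrete, here is a NumPy sketch of a single LSTM layer following the standard LSTM equations, with gates stacked in the order input, forget, cell, output as described in the PyTorch nn.LSTM documentation. It is not PyTorch's implementation (one layer, no dropout, random weights); it only illustrates that `x[:, -1, :]` is the hidden state after the final time step.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_last_output(x, w_ih, w_hh, b_ih, b_hh):
    """Single-layer LSTM over a batch_first input; returns all outputs and the last one.

    x: [n_batch, n_seq, n_in], w_ih: [4*n_hid, n_in], w_hh: [4*n_hid, n_hid].
    """
    n_batch, n_seq, _ = x.shape
    n_hid = w_hh.shape[1]
    h = np.zeros((n_batch, n_hid))          # hidden state
    c = np.zeros((n_batch, n_hid))          # cell state
    outputs = []
    for t in range(n_seq):
        z = x[:, t, :] @ w_ih.T + b_ih + h @ w_hh.T + b_hh    # [n_batch, 4*n_hid]
        i, f, g, o = np.split(z, 4, axis=1)
        i, f, g, o = sigmoid(i), sigmoid(f), np.tanh(g), sigmoid(o)
        c = f * c + i * g
        h = o * np.tanh(c)
        outputs.append(h)
    out = np.stack(outputs, axis=1)                           # [n_batch, n_seq, n_hid]
    return out, out[:, -1, :]                                 # last step: [n_batch, n_hid]

rng = np.random.default_rng(0)
n_batch, n_seq, n_in, n_hid = 3, 5, 8, 4
x = rng.normal(size=(n_batch, n_seq, n_in))
w_ih = rng.normal(scale=0.1, size=(4 * n_hid, n_in))
w_hh = rng.normal(scale=0.1, size=(4 * n_hid, n_hid))
b_ih = np.zeros(4 * n_hid)
b_hh = np.zeros(4 * n_hid)

out, last = lstm_last_output(x, w_ih, w_hh, b_ih, b_hh)
print(out.shape, last.shape)
```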

Hyperparameters

In [94]:
nb_layers = 2
n_vocab = len(i2w)
n_embed = 300 # 400  # 50  # 300
n_hid = 512   # 50
n_out = 1

dropout = .8

Create model

In [95]:
model = SentimentNetwork(nb_layers, n_vocab, n_embed, n_hid, n_out, dropout=dropout)
model.to(device)
optimizer = torch.optim.Adam(model.parameters())
criterion = nn.BCEWithLogitsLoss()
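A hand count of the trainable parameters for the hyperparameters above, as a sketch: `count_parameters` is a hypothetical helper, and in the notebook `sum(p.numel() for p in model.parameters())` should give the same number. nn.LSTM keeps two bias vectors per layer (input-hidden and hidden-hidden), and dropout has no parameters.

```python
def count_parameters(n_vocab, n_embed, n_hid, n_out, nb_layers):
    """Trainable parameters of SentimentNetwork, counted by hand."""
    embed = n_vocab * n_embed
    lstm = 0
    for layer in range(nb_layers):
        n_in = n_embed if layer == 0 else n_hid       # deeper layers read the previous layer's output
        lstm += 4 * n_hid * (n_in + n_hid)            # input-hidden and hidden-hidden weights
        lstm += 2 * 4 * n_hid                         # two bias vectors per layer
    fc = n_hid * n_out + n_out
    return embed + lstm + fc

print(count_parameters(n_vocab=10_000, n_embed=300, n_hid=512, n_out=1, nb_layers=2))
```

Roughly 3.0M of the 6.8M parameters sit in the embedding table.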

Train model

In [96]:
batch_size = 250 # 50 # 250
hist = { 'loss':[], 'acc':[] }
trace = {'epoch': [], 'tloss': [], 'vloss': [], 'tacc': [], 'vacc': []}
In [97]:
for epoch in range(10):  # loop over the dataset multiple times

    time_start = time.time()
    
    #
    #   Train
    #
    model.train()
    tloss_sum, tacc_sum = 0, 0
    indices = torch.randperm(len(train_features), device=device)
    for i in range(0, len(train_features), batch_size):

        # Pick mini-batch
        inputs = train_features[indices[i:i+batch_size]]
        targets = train_targets[indices[i:i+batch_size]]
        
        # Optimize
        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        
        # Record
        with torch.no_grad():
            probabilities = torch.sigmoid(logits)
            predictions = probabilities.round()
            acc = accuracy(predictions, targets)
            hist['loss'].append( loss.item() )
            hist['acc'].append( acc.item() )
            tloss_sum += loss.item() * len(inputs)
            tacc_sum  += acc.item() * len(inputs)
            
    tloss_avg = tloss_sum / len(train_features)
    tacc_avg = tacc_sum / len(train_features)
    
    #
    #   Evaluate
    #
    
    # includes model.eval() and torch.no_grad()
    vloss_avg, vacc_avg, _ = evaluate(test_features, test_targets, batch_size)
    
    #
    #   Logging
    #
    
    time_delta = time.time() - time_start
    
    trace['epoch'].append(epoch)
    trace['tloss'].append(tloss_avg)
    trace['tacc'].append(tacc_avg)
    trace['vloss'].append(vloss_avg)
    trace['vacc'].append(vacc_avg)
    
    print(f'Epoch: {epoch:3}     T/V Loss: {tloss_avg:.4f} / {vloss_avg:.4f}     '
          f'T/V Acc: {tacc_avg:.4f} / {vacc_avg:.4f}     Time: {time_delta:.2f}s')
Epoch:   0     T/V Loss: 0.5742 / 0.5196     T/V Acc: 0.6982 / 0.7462     Time: 37.13s
Epoch:   1     T/V Loss: 0.4355 / 0.4968     T/V Acc: 0.8052 / 0.7860     Time: 39.28s
Epoch:   2     T/V Loss: 0.3783 / 0.4107     T/V Acc: 0.8380 / 0.8260     Time: 38.74s
Epoch:   3     T/V Loss: 0.3218 / 0.3707     T/V Acc: 0.8705 / 0.8440     Time: 38.84s
Epoch:   4     T/V Loss: 0.2530 / 0.3673     T/V Acc: 0.9014 / 0.8393     Time: 39.04s
Epoch:   5     T/V Loss: 0.2221 / 0.3734     T/V Acc: 0.9150 / 0.8460     Time: 38.98s
Epoch:   6     T/V Loss: 0.1843 / 0.3897     T/V Acc: 0.9323 / 0.8497     Time: 39.14s
Epoch:   7     T/V Loss: 0.1501 / 0.4220     T/V Acc: 0.9442 / 0.8520     Time: 39.18s
Epoch:   8     T/V Loss: 0.1129 / 0.4687     T/V Acc: 0.9607 / 0.8465     Time: 39.45s
Epoch:   9     T/V Loss: 0.0838 / 0.5234     T/V Acc: 0.9724 / 0.8463     Time: 39.29s
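In the log, training loss keeps falling through epoch 9 while the validation loss bottoms out at epoch 4 and rises afterwards, a typical sign of overfitting. A small sketch that reads off the best epochs from the logged values (copied from the output above; in the notebook, trace['vloss'] and trace['vacc'] hold the same numbers). Note that these "validation" numbers come from the test set, since evaluate() is called with test_features, so choosing an epoch from them would leak test information into model selection.

```python
# Validation (test-set) loss and accuracy per epoch, copied from the log above
vloss = [0.5196, 0.4968, 0.4107, 0.3707, 0.3673, 0.3734, 0.3897, 0.4220, 0.4687, 0.5234]
vacc = [0.7462, 0.7860, 0.8260, 0.8440, 0.8393, 0.8460, 0.8497, 0.8520, 0.8465, 0.8463]

best_loss_epoch = min(range(len(vloss)), key=vloss.__getitem__)
best_acc_epoch = max(range(len(vacc)), key=vacc.__getitem__)
print(best_loss_epoch, best_acc_epoch)
```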

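Final Results

The train accuracy below (0.99) is higher than the last epoch's running train accuracy (0.9724) because it is measured in eval mode, with dropout disabled, using the final weights, rather than averaged over an epoch while the weights were still changing.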

In [98]:
_, acc, _ = evaluate(train_features, train_targets, batch_size)
print(f'Accuracy on train set: {acc:.2f}')
Accuracy on train set: 0.99
In [99]:
_, acc, _ = evaluate(test_features, test_targets, batch_size)
print(f'Accuracy on test set: {acc:.2f}')
Accuracy on test set: 0.85

Plot training statistics

In [100]:
plot_trace(trace)
In [101]:
plt.plot(hist['loss'], label='loss')
plt.plot(hist['acc'], label='acc', color='red')
plt.legend();