This is a mirror of the official fast.ai course notebook for the Dsnet meetup. Please check the course repo for the latest updates.

This notebook was part of Lesson 7 of the Practical Deep Learning for Coders course.

Predicting English word version of numbers using an RNN

We were using RNNs as part of our language model in the previous lesson. Today, we will dive into more details of what RNNs are and how they work. We will do this using the problem of trying to predict the English word version of numbers.

Let's predict what should come next in this sequence:

eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve...

Jeremy created this synthetic dataset to make it easier to check that things are working, to debug, and to understand what is going on. When experimenting with new ideas, it can be nice to have a smaller dataset so you can quickly get a sense of whether your ideas are promising (for other examples, see Imagenette and Imagewoof). This dataset of numbers written as English words is well suited to learning about RNNs. Our task today is to predict which word comes next when counting.

In deep learning, there are 2 types of numbers:

Parameters are numbers that are learned. Activations are numbers that are calculated (by affine functions & element-wise non-linearities).
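To make the distinction concrete, here is a small PyTorch sketch. The layer sizes are arbitrary choices for illustration. The weight and bias of an `nn.Linear` layer are parameters. Its ReLU-ed output is an activation, which is computed fresh on every forward pass.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)      # affine function with a 3x4 weight and a 3-element bias
x = torch.randn(2, 4)        # a mini-batch of 2 inputs

# Parameters: stored in the module and updated by the optimizer
param_shapes = {name: tuple(p.shape) for name, p in layer.named_parameters()}
n_params = sum(p.numel() for p in layer.parameters())   # 3*4 + 3 = 15

# Activations: calculated from the inputs and parameters on every forward pass
act = torch.relu(layer(x))   # affine function, then an element-wise non-linearity
print(param_shapes)          # {'weight': (3, 4), 'bias': (3,)}
print(n_params, act.shape)   # 15 torch.Size([2, 3])
```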

When you learn about any new concept in deep learning, ask yourself: is this a parameter or an activation?

Note to self: Point out the hidden state, going from the version without a for-loop to the for loop. This is the step where people get confused.

Data

In [2]:
from fastai.text import *
In [3]:
bs=64
In [4]:
path = untar_data(URLs.HUMAN_NUMBERS)
path.ls()
Out[4]:
[PosixPath('/home/racheltho/.fastai/data/human_numbers/models'),
 PosixPath('/home/racheltho/.fastai/data/human_numbers/valid.txt'),
 PosixPath('/home/racheltho/.fastai/data/human_numbers/train.txt')]
In [5]:
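# Read a file with one number per line and join the lines into one comma-separated string (returned as a 1-item list)
def readnums(d): return [', '.join(o.strip() for o in (path/d).read_text().splitlines())]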

train.txt gives us a sequence of numbers written out as English words:

In [6]:
train_txt = readnums('train.txt'); train_txt[0][:80]
Out[6]:
'one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirt'
In [7]:
valid_txt = readnums('valid.txt'); valid_txt[0][-80:]
Out[7]:
' nine thousand nine hundred ninety eight, nine thousand nine hundred ninety nine'
In [8]:
train = TextList(train_txt, path=path)
valid = TextList(valid_txt, path=path)

src = ItemLists(path=path, train=train, valid=valid).label_for_lm()
data = src.databunch(bs=bs)
In [9]:
train[0].text[:80]
Out[9]:
'xxbos one , two , three , four , five , six , seven , eight , nine , ten , eleve'
In [10]:
len(data.valid_ds[0][0].data)
Out[10]:
13017

bptt stands for back-propagation through time. It tells us how many steps of history we are considering: each row of a batch holds bptt tokens.

In [11]:
data.bptt, len(data.valid_dl)
Out[11]:
(70, 3)

We have 3 batches in our validation set. There are 13017 tokens, split into rows of 70 tokens (bptt), with 64 rows per batch:

In [12]:
13017/70/bs
Out[12]:
2.905580357142857
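The arithmetic above assumes the validation stream is cut into 64 long rows, and each row is read 70 tokens at a time. The sketch below illustrates that layout with a hypothetical helper, `lm_batches`, which is not part of fastai. It simply drops the ragged tail. fastai's loader appears to fill the last batch by wrapping around to the start of the text instead (note the `xxbos eight thousand one` at the end of `x3[-1]` further down).

```python
import torch

def lm_batches(stream, bs, bptt):
    """Hypothetical, simplified language-model batcher (not fastai's API).

    Cuts one long token stream into `bs` contiguous rows, then yields
    (x, y) windows of up to `bptt` columns; y is x shifted by one token.
    """
    stream = torch.as_tensor(stream)
    row_len = (len(stream) - 1) // bs              # drop the ragged tail for simplicity
    x_all = stream[:bs * row_len].view(bs, row_len)
    y_all = stream[1:bs * row_len + 1].view(bs, row_len)
    for start in range(0, row_len, bptt):          # walk along the rows, bptt tokens at a time
        yield x_all[:, start:start + bptt], y_all[:, start:start + bptt]

tokens = torch.arange(13017)                       # stand-in for the 13017 validation token ids
batches = list(lm_batches(tokens, bs=64, bptt=70))
print(len(batches))                                # 3
print([tuple(xb.shape) for xb, _ in batches])      # [(64, 70), (64, 70), (64, 63)]
```

With this layout, row 0 of the second batch picks up exactly where row 0 of the first batch stopped. The same pattern shows up in `x1[0]`, `x2[0]` and `x3[0]` below.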

We will store each batch in a separate variable, so we can walk through this to understand better what the RNN does at each step:

In [13]:
it = iter(data.valid_dl)
x1,y1 = next(it)
x2,y2 = next(it)
x3,y3 = next(it)
it.close()
In [14]:
x1
Out[14]:
tensor([[ 2, 19, 11,  ..., 36,  9, 19],
        [ 9, 19, 11,  ..., 24, 20,  9],
        [11, 27, 18,  ...,  9, 19, 11],
        ...,
        [20, 11, 20,  ..., 11, 20, 10],
        [20, 11, 20,  ..., 24,  9, 20],
        [20, 10, 26,  ..., 20, 11, 20]], device='cuda:0')

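numel() is a PyTorch method that returns the number of elements in a tensor. The three batches hold 3 × 64 × 70 = 13440 elements. That is slightly more than the 13017 tokens because the last batch is filled out by wrapping around to the start of the text (visible at the end of x3[-1] below):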

In [15]:
x1.numel()+x2.numel()+x3.numel()
Out[15]:
13440
In [16]:
x1.shape, y1.shape
Out[16]:
(torch.Size([64, 70]), torch.Size([64, 70]))
In [17]:
x2.shape, y2.shape
Out[17]:
(torch.Size([64, 70]), torch.Size([64, 70]))
In [18]:
x3.shape, y3.shape
Out[18]:
(torch.Size([64, 70]), torch.Size([64, 70]))
In [19]:
v = data.valid_ds.vocab
In [20]:
v.itos
Out[20]:
['xxunk',
 'xxpad',
 'xxbos',
 'xxeos',
 'xxfld',
 'xxmaj',
 'xxup',
 'xxrep',
 'xxwrep',
 ',',
 'hundred',
 'thousand',
 'one',
 'two',
 'three',
 'four',
 'five',
 'six',
 'seven',
 'eight',
 'nine',
 'twenty',
 'thirty',
 'forty',
 'fifty',
 'sixty',
 'seventy',
 'eighty',
 'ninety',
 'ten',
 'eleven',
 'twelve',
 'thirteen',
 'fourteen',
 'fifteen',
 'sixteen',
 'seventeen',
 'eighteen',
 'nineteen']
In [21]:
x1[:,0]
Out[21]:
tensor([ 2,  9, 11, 12, 13, 11, 10,  9, 10, 14, 19, 25, 19, 15, 16, 11, 19,  9,
        10,  9, 19, 25, 19, 11, 19, 11, 10,  9, 19, 20, 11, 26, 20, 23, 20, 20,
        24, 20, 11, 14, 11, 11,  9, 14,  9, 20, 10, 20, 35, 17, 11, 10,  9, 17,
         9, 20, 10, 20, 11, 20, 11, 20, 20, 20], device='cuda:0')
In [22]:
y1[:,0]
Out[22]:
tensor([19, 19, 27, 10,  9, 12, 32, 19, 26, 10, 11, 15, 11, 10,  9, 15, 11, 19,
        26, 19, 11, 18, 11, 18,  9, 18, 21, 19, 10, 10, 20,  9, 11, 16, 11, 11,
        13, 11, 13,  9, 13, 14, 20, 10, 20, 11, 24, 11,  9,  9, 16, 17, 20, 10,
        20, 11, 24, 11, 19,  9, 19, 11, 11, 10], device='cuda:0')
In [23]:
v.itos[9], v.itos[11], v.itos[12], v.itos[13], v.itos[10]
Out[23]:
(',', 'thousand', 'one', 'two', 'hundred')
In [24]:
v.textify(x1[0])
Out[24]:
'xxbos eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight'
In [25]:
v.textify(x1[1])
Out[25]:
', eight thousand forty six , eight thousand forty seven , eight thousand forty eight , eight thousand forty nine , eight thousand fifty , eight thousand fifty one , eight thousand fifty two , eight thousand fifty three , eight thousand fifty four , eight thousand fifty five , eight thousand fifty six , eight thousand fifty seven , eight thousand fifty eight , eight thousand fifty nine ,'
In [26]:
v.textify(x2[1])
Out[26]:
'eight thousand sixty , eight thousand sixty one , eight thousand sixty two , eight thousand sixty three , eight thousand sixty four , eight thousand sixty five , eight thousand sixty six , eight thousand sixty seven , eight thousand sixty eight , eight thousand sixty nine , eight thousand seventy , eight thousand seventy one , eight thousand seventy two , eight thousand seventy three , eight thousand'
In [21]:
v.textify(y1[0])
Out[21]:
'eight thousand one , eight thousand two , eight thousand three , eight thousand four , eight thousand five , eight thousand six , eight thousand seven , eight thousand eight , eight thousand nine , eight thousand ten , eight thousand eleven , eight thousand twelve , eight thousand thirteen , eight thousand fourteen , eight thousand fifteen , eight thousand sixteen , eight thousand seventeen , eight thousand'
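Comparing `x1[0]` with `y1[0]` shows that the target is simply the input shifted one token to the left. Below is a minimal pure-Python sketch of that relationship. It uses only the first 21 entries of `v.itos` listed above and a hypothetical `textify` helper that mimics what `v.textify` prints.

```python
itos = ['xxunk', 'xxpad', 'xxbos', 'xxeos', 'xxfld', 'xxmaj', 'xxup', 'xxrep', 'xxwrep',
        ',', 'hundred', 'thousand', 'one', 'two', 'three', 'four', 'five', 'six',
        'seven', 'eight', 'nine']          # first 21 entries of v.itos shown above
stoi = {s: i for i, s in enumerate(itos)}  # token -> id

def textify(ids):
    # hypothetical stand-in for fastai's Vocab.textify: ids -> space-separated tokens
    return ' '.join(itos[i] for i in ids)

stream = [stoi[t] for t in 'xxbos eight thousand one , eight thousand two , eight'.split()]
x, y = stream[:-1], stream[1:]             # language-model target = the next token
print(x[:3])        # [2, 19, 11], the same ids that start x1[0]
print(textify(x))   # xxbos eight thousand one , eight thousand two ,
print(textify(y))   # eight thousand one , eight thousand two , eight
```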
In [54]:
v.textify(x2[0])
Out[54]:
'thousand eighteen , eight thousand nineteen , eight thousand twenty , eight thousand twenty one , eight thousand twenty two , eight thousand twenty three , eight thousand twenty four , eight thousand twenty five , eight thousand twenty six , eight thousand twenty seven , eight thousand twenty eight , eight thousand twenty nine , eight thousand thirty , eight thousand thirty one , eight thousand thirty two ,'
In [55]:
v.textify(x3[0])
Out[55]:
'eight thousand thirty three , eight thousand thirty four , eight thousand thirty five , eight thousand thirty six , eight thousand thirty seven , eight thousand thirty eight , eight thousand thirty nine , eight thousand forty , eight thousand forty one , eight thousand forty two , eight thousand forty three , eight thousand forty four , eight thousand forty five , eight thousand forty six , eight'
In [58]:
v.textify(x3[1])
Out[58]:
'seventy four , eight thousand seventy five , eight thousand seventy six , eight thousand seventy seven , eight thousand seventy eight , eight thousand seventy nine , eight thousand eighty , eight thousand eighty one , eight thousand eighty two , eight thousand eighty three , eight thousand eighty four , eight thousand eighty five , eight thousand eighty six , eight thousand eighty seven , eight thousand eighty'
In [59]:
v.textify(x3[-1])
Out[59]:
'ninety , nine thousand nine hundred ninety one , nine thousand nine hundred ninety two , nine thousand nine hundred ninety three , nine thousand nine hundred ninety four , nine thousand nine hundred ninety five , nine thousand nine hundred ninety six , nine thousand nine hundred ninety seven , nine thousand nine hundred ninety eight , nine thousand nine hundred ninety nine xxbos eight thousand one , eight'
In [60]:
data.show_batch(ds_type=DatasetType.Valid)

We will iteratively consider a few different models, building up to a more traditional RNN.

Single fully connected model

In [61]:
data = src.databunch(bs=bs, bptt=3)
In [62]:
x,y = data.one_batch()
x.shape,y.shape
Out[62]:
(torch.Size([64, 3]), torch.Size([64, 3]))
In [63]:
nv = len(v.itos); nv
Out[63]:
39
In [64]:
nh=64
In [65]:
def loss4(input,target): return F.cross_entropy(input, target[:,-1])  # score only the last target position
def acc4(input,target):  return accuracy(input, target[:,-1])         # accuracy on the last position only
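Model0 below returns a single prediction per sequence (shape `[bs, nv]`), but `y` holds a target for each of the 3 positions. For that reason, `loss4` and `acc4` score only `target[:,-1]`, which is the token that follows the last input token. The following is a shape-level sketch with random stand-in tensors. The values are meaningless and only the shapes matter. The accuracy line is a rough equivalent of fastai's `accuracy`, not its source.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, bptt, nv = 64, 3, 39
preds = torch.randn(bs, nv)                   # one set of logits per sequence, like Model0's output
target = torch.randint(0, nv, (bs, bptt))     # y: the next token at each of the 3 positions

last = target[:, -1]                          # shape (bs,): the token after x[:, -1]
loss = F.cross_entropy(preds, last)           # what loss4 computes
acc = (preds.argmax(dim=-1) == last).float().mean()   # roughly what acc4 measures
```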
In [68]:
x[:,0]
Out[68]:
tensor([13, 13, 10,  9, 18,  9, 11, 11, 13, 19, 16, 23, 24,  9, 12,  9, 13, 14,
        15, 11, 10, 22, 15,  9, 10, 14, 11, 16, 10, 28, 11,  9, 20,  9, 15, 15,
        11, 18, 10, 28, 23, 24,  9, 16, 10, 16, 19, 20, 12, 10, 22, 16, 17, 17,
        17, 11, 24, 10,  9, 15, 16,  9, 18, 11])

Layer names:

  • i_h: input to hidden
  • h_h: hidden to hidden
  • h_o: hidden to output
  • bn: batchnorm
In [67]:
class Model0(nn.Module):
    def __init__(self):
        super().__init__()
        self.i_h = nn.Embedding(nv,nh)  # green arrow
        self.h_h = nn.Linear(nh,nh)     # brown arrow
        self.h_o = nn.Linear(nh,nv)     # blue arrow
        self.bn = nn.BatchNorm1d(nh)
        
    def forward(self, x):
        h = self.bn(F.relu(self.i_h(x[:,0])))
        if x.shape[1]>1:
            h = h + self.i_h(x[:,1])
            h = self.bn(F.relu(self.h_h(h)))
        if x.shape[1]>2:
            h = h + self.i_h(x[:,2])
            h = self.bn(F.relu(self.h_h(h)))
        return self.h_o(h)
In [69]:
learn = Learner(data, Model0(), loss_func=loss4, metrics=acc4)
In [70]:
learn.fit_one_cycle(6, 1e-4)

Same thing with a loop

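Let's refactor this to use a for-loop. This computes almost exactly the same thing as before. The one difference is that the loop starts from an all-zero hidden state, so h_h is now also applied after the first token: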

In [72]:
class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        self.i_h = nn.Embedding(nv,nh)  # green arrow
        self.h_h = nn.Linear(nh,nh)     # brown arrow
        self.h_o = nn.Linear(nh,nv)     # blue arrow
        self.bn = nn.BatchNorm1d(nh)
        
    def forward(self, x):
        h = torch.zeros(x.shape[0], nh, device=x.device)  # start each batch from an all-zero hidden state
        for i in range(x.shape[1]):
            h = h + self.i_h(x[:,i])
            h = self.bn(F.relu(self.h_h(h)))
        return self.h_o(h)

This is the difference between unrolled (what we had before) and rolled (what we have now) RNN diagrams:

In [73]:
learn = Learner(data, Model1(), loss_func=loss4, metrics=acc4)
In [74]:
learn.fit_one_cycle(6, 1e-4)

Our accuracy is about the same, since we are computing essentially the same thing as before.
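The point of the refactor is that the loop reuses the same two layers at every time step. The sketch below omits batchnorm and uses sizes chosen only for illustration. It checks that three hand-written (unrolled) steps and the loop (rolled) give identical hidden states.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
nv, nh, bs = 39, 64, 8
i_h, h_h = nn.Embedding(nv, nh), nn.Linear(nh, nh)   # the same two layers are reused at every step
x = torch.randint(0, nv, (bs, 3))

# Unrolled: one line per time step
h = torch.zeros(bs, nh)
h = F.relu(h_h(h + i_h(x[:, 0])))
h = F.relu(h_h(h + i_h(x[:, 1])))
h = F.relu(h_h(h + i_h(x[:, 2])))
unrolled = h

# Rolled: the same computation as a loop over time steps
h = torch.zeros(bs, nh)
for i in range(x.shape[1]):
    h = F.relu(h_h(h + i_h(x[:, i])))
rolled = h
```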

Multi fully connected model

Before, we were only predicting the last word in each sequence: given 3 tokens, what is token 4? That approach throws away a lot of data. Why not predict token 2 from token 1, then token 3 from tokens 1 and 2, then token 4, and so on? We will modify our model to do this.
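Predicting at every position means the model outputs `[bs, bptt, nv]`, and this output is compared against all of `y` (`[bs, bptt]`). One straightforward way to compute such a loss is to flatten batch and time into a single axis. The sketch below does this with random stand-in tensors. It is not necessarily how fastai computes its default loss for this data.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, bptt, nv = 64, 20, 39
out = torch.randn(bs, bptt, nv)              # one prediction per position, like Model2's output
y = torch.randint(0, nv, (bs, bptt))         # next-token target at every position

# Flatten batch and time so every position counts as one classification example
loss = F.cross_entropy(out.view(-1, nv), y.view(-1))
acc = (out.argmax(dim=-1) == y).float().mean()
```

Moving the class axis to dimension 1 instead, as in `F.cross_entropy(out.transpose(1, 2), y)`, gives the same mean loss without reshaping.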

In [75]:
data = src.databunch(bs=bs, bptt=20)
In [76]:
x,y = data.one_batch()
x.shape,y.shape
Out[76]:
(torch.Size([64, 20]), torch.Size([64, 20]))
In [77]:
class Model2(nn.Module):
    def __init__(self):
        super().__init__()
        self.i_h = nn.Embedding(nv,nh)
        self.h_h = nn.Linear(nh,nh)
        self.h_o = nn.Linear(nh,nv)
        self.bn = nn.BatchNorm1d(nh)
        
    def forward(self, x):
        h = torch.zeros(x.shape[0], nh, device=x.device)  # start each batch from an all-zero hidden state
        res = []
        for i in range(x.shape[1]):
            h = h + self.i_h(x[:,i])
            h = F.relu(self.h_h(h))
            res.append(self.h_o(self.bn(h)))
        return torch.stack(res, dim=1)
In [78]:
learn = Learner(data, Model2(), metrics=accuracy)
In [79]:
learn.fit_one_cycle(10, 1e-4, pct_start=0.1)

Note that our accuracy is worse now, because we are doing a harder task. When we predict the early words in each sequence, we have less history to help us than when we were only predicting the last word.

Maintain state

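To address this issue, let's keep the hidden state from the previous batch instead of resetting it to zeros. Each row of a batch continues the same row of the previous batch (as we saw with x1[0], x2[0] and x3[0]), so the model no longer starts over on each new chunk of text.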

In [80]:
class Model3(nn.Module):
    def __init__(self):
        super().__init__()
        self.i_h = nn.Embedding(nv,nh)
        self.h_h = nn.Linear(nh,nh)
        self.h_o = nn.Linear(nh,nv)
        self.bn = nn.BatchNorm1d(nh)
        self.h = torch.zeros(bs, nh).cuda()  # state carried across batches; assumes a GPU and exactly bs rows per batch
        
    def forward(self, x):
        res = []
        h = self.h
        for i in range(x.shape[1]):
            h = h + self.i_h(x[:,i])
            h = F.relu(self.h_h(h))
            res.append(self.bn(h))
        self.h = h.detach()
        res = torch.stack(res, dim=1)
        res = self.h_o(res)
        return res
In [81]:
learn = Learner(data, Model3(), metrics=accuracy)
In [82]:
learn.fit_one_cycle(20, 3e-3)

Now we are getting greater accuracy than before!
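Model3 calls `h.detach()` at the end of each forward pass. The values of the hidden state carry over to the next batch, but backpropagation stops at the batch boundary. Without this, every batch would backpropagate through the entire history seen so far. Here is a scalar sketch of that effect, with `w` standing in for a weight:

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
h = torch.tensor(1.0)

# "Batch 1": update the state, then keep only its value for the next batch
h = w * h
h = h.detach()                 # the value 2.0 is kept, but the graph back to w is cut

# "Batch 2": the loss now depends on w through this batch only
out = w * h
out.backward()
print(w.grad)                  # tensor(2.): d(w*h)/dw = h, not 2*w = 4 as without detach
```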

nn.RNN

Let's refactor the above to use PyTorch's RNN. This is what you would use in practice, but now you know the inside details!

In [ ]:
class Model4(nn.Module):
    def __init__(self):
        super().__init__()
        self.i_h = nn.Embedding(nv,nh)
        self.rnn = nn.RNN(nh,nh, batch_first=True)
        self.h_o = nn.Linear(nh,nv)
        self.bn = BatchNorm1dFlat(nh)
        self.h = torch.zeros(1, bs, nh).cuda()  # (num_layers, batch, hidden), even with batch_first=True
        
    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)
        self.h = h.detach()
        return self.h_o(self.bn(res))
In [ ]:
learn = Learner(data, Model4(), metrics=accuracy)
In [ ]:
learn.fit_one_cycle(20, 3e-3)
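To see the "inside details" of `nn.RNN`, the sketch below recomputes its output with an explicit loop over time steps, using the layer's own weights. Sizes are arbitrary. Note that `nn.RNN` uses tanh by default rather than the ReLU in Model3 (it also accepts `nonlinearity='relu'`). Model4 also applies batchnorm only to the outputs, not inside the recurrence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bs, seq, nh = 4, 5, 8
rnn = nn.RNN(nh, nh, batch_first=True)       # default nonlinearity is tanh
inp = torch.randn(bs, seq, nh)
h0 = torch.zeros(1, bs, nh)                  # (num_layers, batch, hidden), even with batch_first=True

out, hn = rnn(inp, h0)                       # out: (bs, seq, nh); hn: (1, bs, nh)

# The same recurrence written as an explicit loop over time steps
h = h0[0]
steps = []
for t in range(seq):
    h = torch.tanh(inp[:, t] @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
                   + h @ rnn.weight_hh_l0.T + rnn.bias_hh_l0)
    steps.append(h)
manual = torch.stack(steps, dim=1)
```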

2-layer GRU

When you have long time scales and deeper networks, plain RNNs like the ones above become very hard to train. One way to address this is to add small neural networks (gates) that decide how much of the new input (green arrow) and how much of the previous hidden state (brown arrow) to keep. GRUs and LSTMs are RNN variants built around such gates. We will cover them in more detail in a later lesson.
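As a preview, here is a sketch of the gating inside one GRU step. It is written out by hand from the weights of PyTorch's `nn.GRUCell` and checked against the cell itself. The update gate `z` decides how much of the previous hidden state to keep, and the reset gate `r` decides how much of it feeds the new candidate state. Sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bs, nh = 4, 8
cell = nn.GRUCell(nh, nh)
x, h = torch.randn(bs, nh), torch.randn(bs, nh)

# PyTorch stacks the gate weights in the order (reset, update, new)
W_ir, W_iz, W_in = cell.weight_ih.chunk(3)
W_hr, W_hz, W_hn = cell.weight_hh.chunk(3)
b_ir, b_iz, b_in = cell.bias_ih.chunk(3)
b_hr, b_hz, b_hn = cell.bias_hh.chunk(3)

r = torch.sigmoid(x @ W_ir.T + b_ir + h @ W_hr.T + b_hr)     # reset gate
z = torch.sigmoid(x @ W_iz.T + b_iz + h @ W_hz.T + b_hz)     # update gate
n = torch.tanh(x @ W_in.T + b_in + r * (h @ W_hn.T + b_hn))  # candidate state from the new input
h_new = (1 - z) * n + z * h                                    # blend candidate and previous state
```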

In [83]:
class Model5(nn.Module):
    def __init__(self):
        super().__init__()
        self.i_h = nn.Embedding(nv,nh)
        self.rnn = nn.GRU(nh, nh, 2, batch_first=True)
        self.h_o = nn.Linear(nh,nv)
        self.bn = BatchNorm1dFlat(nh)
        self.h = torch.zeros(2, bs, nh).cuda()  # one hidden state per GRU layer: (2, batch, hidden)
        
    def forward(self, x):
        res,h = self.rnn(self.i_h(x), self.h)
        self.h = h.detach()
        return self.h_o(self.bn(res))
In [ ]:
learn = Learner(data, Model5(), metrics=accuracy)
In [ ]:
learn.fit_one_cycle(10, 1e-2)

Connection to ULMFiT

In the previous lesson, we were essentially swapping out self.h_o with a classifier in order to do classification on text.
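Below is a minimal sketch of that swap. All class names are hypothetical. A shared encoder (embedding plus recurrent layers) is trained once with a next-token head. The same encoder is then reused with a classification head in place of `h_o`. The real ULMFiT classifier head is more elaborate (it pools over all time steps). This sketch classifies from the last output only.

```python
import torch
import torch.nn as nn

nv, nh, n_classes = 39, 64, 2

class Encoder(nn.Module):
    # shared part learned by the language model: embedding + 2-layer GRU
    def __init__(self):
        super().__init__()
        self.i_h = nn.Embedding(nv, nh)
        self.rnn = nn.GRU(nh, nh, 2, batch_first=True)
    def forward(self, x):
        res, _ = self.rnn(self.i_h(x))
        return res                                        # (bs, seq, nh)

class LanguageModel(nn.Module):
    def __init__(self, enc):
        super().__init__()
        self.enc, self.h_o = enc, nn.Linear(nh, nv)       # head: next token at every step
    def forward(self, x):
        return self.h_o(self.enc(x))                      # (bs, seq, nv)

class TextClassifier(nn.Module):
    def __init__(self, enc):
        super().__init__()
        self.enc, self.head = enc, nn.Linear(nh, n_classes)   # h_o swapped for a classifier
    def forward(self, x):
        return self.head(self.enc(x)[:, -1])              # (bs, n_classes) from the last output

enc = Encoder()
lm, clf = LanguageModel(enc), TextClassifier(enc)         # both share the same encoder weights
x = torch.randint(0, nv, (3, 10))
```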

fin

An RNN is just a refactored fully connected neural network.

You can use the same approach for any sequence labeling task (part-of-speech tagging, classifying whether material is sensitive, etc.).
