chapter_recurrent-modern/gru_origin.md
:label:sec_gru
In :numref:sec_bptt,
we discussed how gradients are calculated
in RNNs.
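In particular, we found that long products of matrices can lead
to vanishing or exploding gradients.
In practice, such gradient anomalies matter in several situations:
an early observation may be highly significant for predicting all later ones,
some tokens may carry no pertinent information and should be skipped,
and there may be a logical break in a sequence after which the internal state should be reset.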
A number of methods have been proposed to address this. One of the earliest is long short-term memory :cite:Hochreiter.Schmidhuber.1997 which we
will discuss in :numref:sec_lstm. The gated recurrent unit (GRU)
:cite:Cho.Van-Merrienboer.Bahdanau.ea.2014 is a slightly more streamlined
variant that often offers comparable performance and is significantly faster to
compute :cite:Chung.Gulcehre.Cho.ea.2014.
Due to its simplicity, let us start with the GRU.
The key distinction between vanilla RNNs and GRUs is that the latter support gating of the hidden state. This means that we have dedicated mechanisms for when a hidden state should be updated and also when it should be reset. These mechanisms are learned and they address the concerns listed above. For instance, if the first token is of great importance we will learn not to update the hidden state after the first observation. Likewise, we will learn to skip irrelevant temporary observations. Last, we will learn to reset the latent state whenever needed. We discuss this in detail below.
The first things we need to introduce are the reset gate and the update gate. We engineer them to be vectors with entries in $(0, 1)$ such that we can perform convex combinations. For instance, a reset gate allows us to control how much of the previous state we still want to remember. Likewise, an update gate allows us to control how much of the new state is just a copy of the old state.
We begin by engineering these gates.
:numref:fig_gru_1 illustrates the inputs for both
the reset and update gates in a GRU, given the input
of the current time step
and the hidden state of the previous time step.
The outputs of the two gates
are given by two fully-connected layers
with a sigmoid activation function.
:label:fig_gru_1
Mathematically, for a given time step $t$, suppose that the input is a minibatch $\mathbf{X}_t \in \mathbb{R}^{n \times d}$ (number of examples: $n$, number of inputs: $d$) and the hidden state of the previous time step is $\mathbf{H}_{t-1} \in \mathbb{R}^{n \times h}$ (number of hidden units: $h$). Then, the reset gate $\mathbf{R}_t \in \mathbb{R}^{n \times h}$ and update gate $\mathbf{Z}_t \in \mathbb{R}^{n \times h}$ are computed as follows:
$$
\begin{aligned}
\mathbf{R}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xr} + \mathbf{H}_{t-1} \mathbf{W}_{hr} + \mathbf{b}_r),\\
\mathbf{Z}_t = \sigma(\mathbf{X}_t \mathbf{W}_{xz} + \mathbf{H}_{t-1} \mathbf{W}_{hz} + \mathbf{b}_z),
\end{aligned}
$$
where $\mathbf{W}_{xr}, \mathbf{W}_{xz} \in \mathbb{R}^{d \times h}$ and
$\mathbf{W}_{hr}, \mathbf{W}_{hz} \in \mathbb{R}^{h \times h}$ are weight
parameters and $\mathbf{b}_r, \mathbf{b}_z \in \mathbb{R}^{1 \times h}$ are
biases.
Note that broadcasting (see :numref:subsec_broadcasting) is triggered during the summation.
We use sigmoid functions (as introduced in :numref:sec_mlp) to transform input values to the interval $(0, 1)$.
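The gate computation above can be sketched in a few lines. The following is only a minimal illustration in plain NumPy, chosen here to keep the example lightweight; the sizes `n`, `d`, `h` and the random weights are arbitrary assumptions rather than values from the text. It shows the shapes involved, the broadcasting of the $1 \times h$ bias over the minibatch dimension, and that every gate entry lies strictly between 0 and 1.

```python
import numpy as np

def sigmoid(x):
    """Elementwise logistic function, mapping real numbers into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)
n, d, h = 4, 5, 3  # illustrative sizes: examples, inputs, hidden units

X_t = rng.normal(size=(n, d))     # minibatch input at time step t
H_prev = rng.normal(size=(n, h))  # hidden state of the previous time step

# Reset gate parameters (random stand-ins for learned weights)
W_xr, W_hr, b_r = rng.normal(size=(d, h)), rng.normal(size=(h, h)), np.zeros((1, h))
# Update gate parameters (random stand-ins for learned weights)
W_xz, W_hz, b_z = rng.normal(size=(d, h)), rng.normal(size=(h, h)), np.zeros((1, h))

# The (1, h) biases are broadcast across the n rows during the summation
R_t = sigmoid(X_t @ W_xr + H_prev @ W_hr + b_r)
Z_t = sigmoid(X_t @ W_xz + H_prev @ W_hz + b_z)

print(R_t.shape, Z_t.shape)  # (4, 3) (4, 3)
```

Both gates have the same shape as the hidden state, which is what makes the elementwise products in the following equations well defined.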
Next, let us
integrate the reset gate $\mathbf{R}_t$ with
the regular latent state updating mechanism
in :eqref:rnn_h_with_state.
It leads to the following
candidate hidden state
$\tilde{\mathbf{H}}_t \in \mathbb{R}^{n \times h}$ at time step $t$:
$$\tilde{\mathbf{H}}_t = \tanh(\mathbf{X}_t \mathbf{W}_{xh} + \left(\mathbf{R}_t \odot \mathbf{H}_{t-1}\right) \mathbf{W}_{hh} + \mathbf{b}_h),$$
:eqlabel:gru_tilde_H
where $\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}$ and $\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}$ are weight parameters, $\mathbf{b}_h \in \mathbb{R}^{1 \times h}$ is the bias, and the symbol $\odot$ is the Hadamard (elementwise) product operator. Here we use a nonlinearity in the form of tanh to ensure that the values in the candidate hidden state remain in the interval $(-1, 1)$.
The result is a candidate since we still need to incorporate the action of the update gate.
Comparing with :eqref:rnn_h_with_state,
now the influence of the previous states
can be reduced with the
elementwise multiplication of
$\mathbf{R}_t$ and $\mathbf{H}_{t-1}$
in :eqref:gru_tilde_H.
Whenever the entries in the reset gate $\mathbf{R}_t$ are close to 1, we recover a vanilla RNN such as in :eqref:rnn_h_with_state.
For all entries of the reset gate $\mathbf{R}_t$ that are close to 0, the candidate hidden state is the result of an MLP with $\mathbf{X}_t$ as the input. Any pre-existing hidden state is thus reset to defaults.
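The two limiting cases just described can be checked numerically. The sketch below is an illustration in plain NumPy with arbitrary sizes and random weights (all assumptions made for this example). It also sets the reset gate to exactly 1 or exactly 0, which a sigmoid can only approach, so it shows the limits rather than values a trained gate would actually take.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, h = 2, 4, 3  # illustrative sizes: examples, inputs, hidden units

X_t = rng.normal(size=(n, d))
H_prev = rng.normal(size=(n, h))
W_xh = rng.normal(size=(d, h))
W_hh = rng.normal(size=(h, h))
b_h = np.zeros((1, h))

def candidate(R_t):
    """Candidate hidden state for a given reset gate R_t."""
    return np.tanh(X_t @ W_xh + (R_t * H_prev) @ W_hh + b_h)

# Reset gate fully open: identical to the vanilla RNN update
vanilla = np.tanh(X_t @ W_xh + H_prev @ W_hh + b_h)
open_gate = candidate(np.ones((n, h)))

# Reset gate fully closed: the previous hidden state no longer contributes
input_only = np.tanh(X_t @ W_xh + b_h)
closed_gate = candidate(np.zeros((n, h)))

print(np.allclose(open_gate, vanilla))       # True
print(np.allclose(closed_gate, input_only))  # True
```

Because the gate acts elementwise, a trained model can sit anywhere between these extremes, and differently for each hidden unit.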
:numref:fig_gru_2 illustrates the computational flow after applying the reset gate.
:label:fig_gru_2
Finally, we need to incorporate the effect of the update gate $\mathbf{Z}_t$. This determines the extent to which the new hidden state $\mathbf{H}_t \in \mathbb{R}^{n \times h}$ is just the old state $\mathbf{H}_{t-1}$ and by how much the new candidate state $\tilde{\mathbf{H}}_t$ is used. The update gate $\mathbf{Z}_t$ can be used for this purpose, simply by taking elementwise convex combinations of $\mathbf{H}_{t-1}$ and $\tilde{\mathbf{H}}_t$. This leads to the final update equation for the GRU:
$$\mathbf{H}_t = \mathbf{Z}_t \odot \mathbf{H}_{t-1} + (1 - \mathbf{Z}_t) \odot \tilde{\mathbf{H}}_t.$$
Whenever the update gate $\mathbf{Z}_t$ is close to 1, we simply retain the old state. In this case the information from $\mathbf{X}_t$ is essentially ignored, effectively skipping time step $t$ in the dependency chain. In contrast, whenever $\mathbf{Z}_t$ is close to 0, the new latent state $\mathbf{H}_t$ approaches the candidate latent state $\tilde{\mathbf{H}}_t$. These designs can help us cope with the vanishing gradient problem in RNNs and better capture dependencies for sequences with large time step distances. For instance, if the update gate has been close to 1 for all the time steps of an entire subsequence, the old hidden state at the time step of its beginning will be easily retained and passed to its end, regardless of the length of the subsequence.
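To make the retention argument concrete, the following sketch applies the update equation repeatedly with a constant update gate. It is a simplified illustration in plain NumPy: the candidate states are random stand-ins rather than outputs of a real GRU, and the helper `run` and all sizes are hypothetical names and values introduced only for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
n, h, num_steps = 2, 3, 50  # illustrative sizes and subsequence length

H_0 = rng.normal(size=(n, h))  # hidden state at the start of the subsequence

def run(z_value):
    """Apply H = Z * H + (1 - Z) * H_tilde for num_steps steps with constant Z."""
    H = H_0.copy()
    Z = np.full((n, h), z_value)
    for _ in range(num_steps):
        H_tilde = np.tanh(rng.normal(size=(n, h)))  # stand-in candidate state
        H = Z * H + (1 - Z) * H_tilde
    return H, H_tilde

H_keep, _ = run(1.0)                  # gate fixed at 1: old state is copied
H_replace, last_candidate = run(0.0)  # gate fixed at 0: state is overwritten

print(np.allclose(H_keep, H_0))                # True
print(np.allclose(H_replace, last_candidate))  # True
```

With the gate fixed at 1 the initial state survives all 50 steps unchanged, regardless of the candidates seen along the way. For a gate value slightly below 1, say 0.99, the contribution of the initial state decays geometrically and is scaled by $0.99^{50} \approx 0.6$ after 50 steps.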
:numref:fig_gru_3 illustrates the computational flow after the update gate is in action.
:label:fig_gru_3
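In summary, GRUs have two distinguishing features: reset gates, which control how much of the previous hidden state enters the candidate state, and update gates, which control how much of the old state is carried over unchanged and thereby help capture dependencies across large time step distances.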
To gain a better understanding of the GRU model, let us implement it from scratch. We begin by reading the time machine dataset that we used in :numref:sec_rnn_scratch. The code for reading the dataset is given below.
from d2l import mxnet as d2l
from mxnet import np, npx
from mxnet.gluon import rnn
npx.set_np()
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
#@tab pytorch
from d2l import torch as d2l
import torch
from torch import nn
batch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
The next step is to initialize the model parameters.
We draw the weights from a Gaussian distribution
with a standard deviation of 0.01 and set the biases to 0. The hyperparameter num_hiddens defines the number of hidden units.
We instantiate all weights and biases relating to the update gate, the reset gate, the candidate hidden state,
and the output layer.
def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return np.random.normal(scale=0.01, size=shape, ctx=device)

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                np.zeros(num_hiddens, ctx=device))

    W_xz, W_hz, b_z = three()  # Update gate parameters
    W_xr, W_hr, b_r = three()  # Reset gate parameters
    W_xh, W_hh, b_h = three()  # Candidate hidden state parameters
    # Output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = np.zeros(num_outputs, ctx=device)
    # Attach gradients
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.attach_grad()
    return params
#@tab pytorch
def get_params(vocab_size, num_hiddens, device):
    num_inputs = num_outputs = vocab_size

    def normal(shape):
        return torch.randn(size=shape, device=device) * 0.01

    def three():
        return (normal((num_inputs, num_hiddens)),
                normal((num_hiddens, num_hiddens)),
                d2l.zeros(num_hiddens, device=device))

    W_xz, W_hz, b_z = three()  # Update gate parameters
    W_xr, W_hr, b_r = three()  # Reset gate parameters
    W_xh, W_hh, b_h = three()  # Candidate hidden state parameters
    # Output layer parameters
    W_hq = normal((num_hiddens, num_outputs))
    b_q = d2l.zeros(num_outputs, device=device)
    # Attach gradients
    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    for param in params:
        param.requires_grad_(True)
    return params
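As a quick sanity check on the initialization above, one can count how many scalar parameters it creates. The helper `gru_param_count` below is a hypothetical name introduced for illustration, and the vocabulary size of 28 is an assumed value used only as an example, not something stated in this section.

```python
def gru_param_count(vocab_size, num_hiddens):
    """Count the scalars created by a get_params-style initializer."""
    num_inputs = num_outputs = vocab_size
    # Each of the update gate, reset gate, and candidate state has
    # an input weight, a hidden weight, and a bias
    per_gate = (num_inputs * num_hiddens
                + num_hiddens * num_hiddens
                + num_hiddens)
    # Output layer: weight plus bias
    output_layer = num_hiddens * num_outputs + num_outputs
    return 3 * per_gate + output_layer

# Assumed example values: 28 distinct tokens, 256 hidden units
print(gru_param_count(vocab_size=28, num_hiddens=256))  # 226076
```

The factor of three reflects that the GRU keeps one set of recurrent parameters for each of the two gates and one for the candidate hidden state, whereas a vanilla RNN has a single set.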
Now we will define the hidden state initialization function init_gru_state. Just like the init_rnn_state function defined in :numref:sec_rnn_scratch, this function returns a tuple containing a single tensor of shape (batch size, number of hidden units) whose values are all zeros.
def init_gru_state(batch_size, num_hiddens, device):
    return (np.zeros(shape=(batch_size, num_hiddens), ctx=device), )
#@tab pytorch
def init_gru_state(batch_size, num_hiddens, device):
    return (torch.zeros((batch_size, num_hiddens), device=device), )
Now we are ready to define the GRU model. Its structure is the same as that of the basic RNN cell, except that the update equations are more complex.
def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = npx.sigmoid(np.dot(X, W_xz) + np.dot(H, W_hz) + b_z)
        R = npx.sigmoid(np.dot(X, W_xr) + np.dot(H, W_hr) + b_r)
        H_tilda = np.tanh(np.dot(X, W_xh) + np.dot(R * H, W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = np.dot(H, W_hq) + b_q
        outputs.append(Y)
    return np.concatenate(outputs, axis=0), (H,)
#@tab pytorch
def gru(inputs, state, params):
    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params
    H, = state
    outputs = []
    for X in inputs:
        Z = torch.sigmoid((X @ W_xz) + (H @ W_hz) + b_z)
        R = torch.sigmoid((X @ W_xr) + (H @ W_hr) + b_r)
        H_tilda = torch.tanh((X @ W_xh) + ((R * H) @ W_hh) + b_h)
        H = Z * H + (1 - Z) * H_tilda
        Y = H @ W_hq + b_q
        outputs.append(Y)
    return torch.cat(outputs, dim=0), (H,)
Training and prediction work in exactly the same manner as in :numref:sec_rnn_scratch.
After training,
we print out the perplexity on the training set
and the predicted sequence following
the provided prefixes "time traveller" and "traveller", respectively.
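As a reminder of what the reported number means, perplexity is the exponential of the average negative log-likelihood that the model assigns to the true next tokens. The helper `perplexity` below is a hypothetical, standalone illustration using only the standard library; it is not the routine that `d2l.train_ch8` uses internally, and the probabilities passed to it are made-up example values.

```python
import math

def perplexity(target_probs):
    """Exponential of the average negative log-likelihood of the true tokens.

    target_probs holds the probability the model assigned to each
    correct next token in a sequence.
    """
    nll = [-math.log(p) for p in target_probs]
    return math.exp(sum(nll) / len(nll))

# A model that always predicts the correct token with certainty
print(perplexity([1.0, 1.0, 1.0]))  # 1.0

# A model that guesses uniformly among 4 tokens: approximately 4
print(perplexity([0.25] * 4))
```

A perplexity of 1 is therefore the best possible value, while uniform guessing over the vocabulary yields a perplexity equal to the vocabulary size.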
#@tab all
vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_params,
                            init_gru_state, gru)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
In high-level APIs, we can directly instantiate a GRU model. This encapsulates all the configuration details that we made explicit above. The code is significantly faster as it uses compiled operators rather than Python for many details that we spelled out before.
gru_layer = rnn.GRU(num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
#@tab pytorch
num_inputs = vocab_size
gru_layer = nn.GRU(num_inputs, num_hiddens)
model = d2l.RNNModel(gru_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
Compare the rnn.RNN and rnn.GRU implementations with each other.

:begin_tab:mxnet
Discussions
:end_tab:
:begin_tab:pytorch
Discussions
:end_tab: