RNN
January 19, 2024
CIRCLING BACK TO RECURRENT MODELS OF LANGUAGE
[2211.01848]
A new SOTA from DeepMind in the world of recurrent models. Before that, the Mogrifier LSTM had held the top of the benchmarks for four years. The benchmarks in question are Penn Treebank and WikiText-2, where models are compared by training end-to-end on these datasets alone, with no additional data.
The idea of the paper is a modification of the LSTM (dubbed the Rewired LSTM) plus the addition of the Mogrifier.
The Mogrifier is a technique in which the input vector x and the hidden state iteratively modulate each other. For this, two extra matrices Q ∈ m×n and R ∈ n×m are introduced, which the authors propose to factorize into lower-rank matrices Qleft ∈ m×k, Qright ∈ k×n, where k < min(m, n) is the rank.
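Written out as an update rule (this mirrors the mogrify method in the code below; σ is the logistic sigmoid, ⊙ is element-wise multiplication), the rounds alternate:

x ← 2σ(Qh) ⊙ x   (odd rounds)
h ← 2σ(Rx) ⊙ h   (even rounds)

With the factorization, Q is replaced by Qleft·Qright and R analogously; in the implementation below each product also carries a bias term.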
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class LSTMCell(nn.LSTMCell):
    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__(input_size, hidden_size, bias)

    def forward(self, input: Tensor, hidden=None) -> tuple[Tensor, Tensor]:
        if hidden is None:
            h_prev = input.new_zeros(input.size(0), self.hidden_size)
            c_prev = input.new_zeros(input.size(0), self.hidden_size)
        else:
            h_prev, c_prev = hidden
        gates = F.linear(input, self.weight_ih, self.bias_ih) \
            + F.linear(h_prev, self.weight_hh, self.bias_hh)
        i, j, f, o = gates.chunk(4, 1)
        i = torch.sigmoid(i)
        j = torch.tanh(j)
        f = torch.sigmoid(f)
        c = f*c_prev + i*j
        o = torch.sigmoid(o)
        h = o * torch.tanh(c)
        return h, c


class RLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.linear1 = nn.Linear(input_size+hidden_size, hidden_size*2)
        self.linear2 = nn.Linear(hidden_size*2, hidden_size)
        self.linear3 = nn.Linear(hidden_size, hidden_size)

    def forward(self, input: Tensor, hidden=None) -> tuple[Tensor, Tensor]:
        if hidden is None:
            h_prev = input.new_zeros(input.size(0), self.hidden_size)
            c_prev = input.new_zeros(input.size(0), self.hidden_size)
        else:
            h_prev, c_prev = hidden
        ij_gates = self.linear1(torch.cat([input, h_prev], dim=-1))
        i, j = torch.chunk(ij_gates, chunks=2, dim=1)
        i, j = torch.sigmoid(i), torch.tanh(j)
        f = self.linear2(torch.cat([i*j, h_prev], dim=-1))
        f = torch.sigmoid(f)
        c = f*c_prev + torch.min(i, 1-f)*j
        o = torch.sigmoid(self.linear3(c))
        h = o * torch.tanh(c)
        return h, c


class MogRLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, mog_iterations=0, rank=0):
        super().__init__()
        assert rank < min(input_size, hidden_size)
        assert mog_iterations == 0 or mog_iterations > 1
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.mog_iterations = mog_iterations
        self.rank = rank
        self.linear1 = nn.Linear(input_size+hidden_size, hidden_size*2)
        self.linear2 = nn.Linear(hidden_size*2, hidden_size)
        self.linear3 = nn.Linear(hidden_size, hidden_size)
        if rank == 0 and mog_iterations > 0:
            self.Q = nn.Linear(hidden_size, input_size)
            self.R = nn.Linear(input_size, hidden_size)
        elif mog_iterations > 0:
            self.Ql = nn.Parameter(torch.Tensor(hidden_size, rank))
            self.Qr = nn.Parameter(torch.Tensor(rank, input_size))
            self.Qb = nn.Parameter(torch.zeros(input_size))
            self.Rl = nn.Parameter(torch.Tensor(input_size, rank))
            self.Rr = nn.Parameter(torch.Tensor(rank, hidden_size))
            self.Rb = nn.Parameter(torch.zeros(hidden_size))
            self._init_mog_low_rank_matrices()

    def _init_mog_low_rank_matrices(self):
        from functools import partial

        def make_low_rank_factorization_initializer(fan_in, rank):
            variance = 1.0 / fan_in
            # Each element of a*b (the low rank matrices) is the sum of 'rank'
            # terms, each of which is a product of an element from 'a' and 'b'.
            stddev = math.sqrt(math.sqrt(variance / rank))
            return partial(nn.init.trunc_normal_, std=stddev)

        q_init = make_low_rank_factorization_initializer(self.hidden_size, self.rank)
        r_init = make_low_rank_factorization_initializer(self.input_size, self.rank)
        q_init(self.Ql)
        q_init(self.Qr)
        r_init(self.Rl)
        r_init(self.Rr)

    def mogrify(self, xt, ht):
        # |xt| : (batch_size, input_size)
        # |ht| : (batch_size, hidden_size)
        for i in range(1, self.mog_iterations+1):
            if (i % 2 == 0):
                if self.rank == 0:
                    ht = (2*torch.sigmoid(self.R(xt))) * ht
                else:
                    _x = xt@self.Rl
                    # |_x| : (batch_size, rank)
                    _x = _x@self.Rr + self.Rb
                    # |_x| : (batch_size, hidden_size)
                    ht = 2*torch.sigmoid(_x) * ht
            else:
                if self.rank == 0:
                    xt = (2*torch.sigmoid(self.Q(ht))) * xt
                else:
                    _h = ht@self.Ql
                    # |_h| : (batch_size, rank)
                    _h = _h@self.Qr + self.Qb
                    # |_h| : (batch_size, input_size)
                    xt = 2*torch.sigmoid(_h) * xt
        return xt, ht

    def forward(self, input: Tensor, hidden=None) -> tuple[Tensor, Tensor]:
        if hidden is None:
            h_prev = input.new_zeros(input.size(0), self.hidden_size)
            c_prev = input.new_zeros(input.size(0), self.hidden_size)
        else:
            h_prev, c_prev = hidden
        input, h_prev = self.mogrify(input, h_prev)
        ij_gates = self.linear1(torch.cat([input, h_prev], dim=-1))
        i, j = torch.chunk(ij_gates, chunks=2, dim=1)
        i, j = torch.sigmoid(i), torch.tanh(j)
        f = self.linear2(torch.cat([i*j, h_prev], dim=-1))
        f = torch.sigmoid(f)
        c = f*c_prev + torch.min(i, 1-f)*j
        o = torch.sigmoid(self.linear3(c))
        h = o * torch.tanh(c)
        return h, c
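
A minimal usage sketch (the sizes and hyperparameters here are illustrative, not taken from the paper): unrolling MogRLSTMCell over a batch of embeddings one time step at a time.

# Illustrative sizes only; not the paper's settings.
batch_size, seq_len, input_size, hidden_size = 4, 16, 128, 256
cell = MogRLSTMCell(input_size, hidden_size, mog_iterations=4, rank=32)

x = torch.randn(batch_size, seq_len, input_size)  # e.g. token embeddings
hidden, outputs = None, []
for t in range(seq_len):
    h, c = cell(x[:, t], hidden)   # mogrify, then the rewired LSTM update
    hidden = (h, c)
    outputs.append(h)
outputs = torch.stack(outputs, dim=1)
# |outputs| : (batch_size, seq_len, hidden_size)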