RNN
January 19, 2024

CIRCLING BACK TO RECURRENT MODELS OF LANGUAGE

[arXiv:2211.01848]
A new SOTA from DeepMind in the world of recurrent networks. Before it, the Mogrifier LSTM had held the top of the leaderboards for 4 years. These are the Penn Treebank and WikiText-2 benchmarks, where models are compared by training end-to-end on those datasets alone, with no additional data.

The idea of the paper is to modify the LSTM (they call it the Rewired LSTM) and add the Mogrifier on top.

The Mogrifier is a technique in which the input vector x and the hidden state iteratively modulate each other, alternating x ← 2σ(Qh) ⊙ x and h ← 2σ(Rx) ⊙ h. This introduces two extra matrices Q ∈ ℝ^{m×n} and R ∈ ℝ^{n×m}, which the authors propose to factorize into lower-rank matrices Q_left ∈ ℝ^{m×k} and Q_right ∈ ℝ^{k×n}, where the rank k < min(m, n).
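
As a quick illustration of the factorization (a minimal sketch with made-up sizes, not from the paper), replacing a full m×n matrix with the product of an m×k and a k×n factor cuts the parameter count from m·n to k·(m+n):

import torch

m, n, k = 512, 256, 24            # hypothetical sizes; k is the chosen rank
Q_full = torch.randn(m, n)        # m*n = 131072 parameters
Q_left = torch.randn(m, k)        # the two low-rank factors together hold
Q_right = torch.randn(k, n)       # k*(m+n) = 18432 parameters
Q_approx = Q_left @ Q_right       # a rank-k matrix used in place of Q_full
print(Q_full.numel(), Q_left.numel() + Q_right.numel())  # 131072 18432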

The LSTM cell:

import math
from functools import partial

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class LSTMCell(nn.LSTMCell):
    def __init__(self, input_size, hidden_size, bias=True):
        super().__init__(input_size, hidden_size, bias)

    def forward(self, input: Tensor, hidden=None) -> tuple[Tensor, Tensor]:
        if hidden is None:
            h_prev = input.new_zeros(input.size(0), self.hidden_size)
            c_prev = input.new_zeros(input.size(0), self.hidden_size)
        else:
            h_prev, c_prev = hidden

        # One fused projection computes all four gates at once.
        gates = F.linear(input, self.weight_ih, self.bias_ih) \
                + F.linear(h_prev, self.weight_hh, self.bias_hh)

        # i: input gate, j: candidate cell state, f: forget gate, o: output gate.
        i, j, f, o = gates.chunk(4, 1)

        i = torch.sigmoid(i)
        j = torch.tanh(j)
        f = torch.sigmoid(f)
        c = f*c_prev + i*j
        o = torch.sigmoid(o)
        h = o * torch.tanh(c)

        return h, c
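
A minimal usage sketch (the batch size, dimensions, and sequence length below are arbitrary): the cell is stepped over the time axis manually, carrying (h, c) forward.

cell = LSTMCell(input_size=128, hidden_size=256)
x = torch.randn(16, 10, 128)      # (batch, seq_len, input_size)
hidden = None
for t in range(x.size(1)):
    h, c = cell(x[:, t], hidden)  # hidden starts as zeros on the first step
    hidden = (h, c)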

The RLSTM cell:

class RLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # i, j gates from [x, h_prev]; f gate from [i*j, h_prev]; o gate from c.
        self.linear1 = nn.Linear(input_size+hidden_size, hidden_size*2)
        self.linear2 = nn.Linear(hidden_size*2, hidden_size)
        self.linear3 = nn.Linear(hidden_size, hidden_size)

    def forward(self, input: Tensor, hidden=None) -> tuple[Tensor, Tensor]:
        if hidden is None:
            h_prev = input.new_zeros(input.size(0), self.hidden_size)
            c_prev = input.new_zeros(input.size(0), self.hidden_size)
        else:
            h_prev, c_prev = hidden

        ij_gates = self.linear1(torch.cat([input, h_prev], dim=-1))
        i, j = torch.chunk(ij_gates, chunks=2, dim=1)
        i, j = torch.sigmoid(i), torch.tanh(j)

        # The rewiring: the forget gate sees the candidate update i*j
        # rather than the raw input.
        f = self.linear2(torch.cat([i*j, h_prev], dim=-1))
        f = torch.sigmoid(f)

        # Capping the input gate at 1-f keeps |c| bounded by 1,
        # since |j| <= 1 and c starts at zero.
        c = f*c_prev + torch.min(i, 1-f)*j

        o = torch.sigmoid(self.linear3(c))
        h = o * torch.tanh(c)

        return h, c
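
A quick empirical check of the boundedness noted in the comments above (a throwaway sketch with arbitrary sizes): even after many steps on random inputs, the cell state stays inside [-1, 1].

cell = RLSTMCell(input_size=32, hidden_size=64)
hidden = None
with torch.no_grad():
    for _ in range(1000):
        h, c = cell(torch.randn(8, 32), hidden)
        hidden = (h, c)
print(c.abs().max())  # <= 1.0 by construction of the min(i, 1-f) gating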

The MogrifierRLSTM cell:

class MogRLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, mog_iterations=0, rank=0):
        super().__init__()
        # rank == 0 selects the full-rank mogrifier matrices below.
        assert 0 <= rank < min(input_size, hidden_size)
        assert mog_iterations >= 0

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.mog_iterations = mog_iterations
        self.rank = rank

        self.linear1 = nn.Linear(input_size+hidden_size, hidden_size*2)
        self.linear2 = nn.Linear(hidden_size*2, hidden_size)
        self.linear3 = nn.Linear(hidden_size, hidden_size)
        
        if rank == 0 and mog_iterations > 0:
            # Full-rank mogrifier projections: Q maps h into the input space,
            # R maps x into the hidden space.
            self.Q = nn.Linear(hidden_size, input_size)
            self.R = nn.Linear(input_size, hidden_size)

        elif mog_iterations > 0:
            # Low-rank factorization: Q = Ql @ Qr and R = Rl @ Rr, with biases.
            self.Ql = nn.Parameter(torch.Tensor(hidden_size, rank))
            self.Qr = nn.Parameter(torch.Tensor(rank, input_size))
            self.Qb = nn.Parameter(torch.zeros(input_size))

            self.Rl = nn.Parameter(torch.Tensor(input_size, rank))
            self.Rr = nn.Parameter(torch.Tensor(rank, hidden_size))
            self.Rb = nn.Parameter(torch.zeros(hidden_size))

            self._init_mog_low_rank_matrices()

    def _init_mog_low_rank_matrices(self):
        def make_low_rank_factorization_initializer(fan_in, rank):
            # Each element of the product of the two low-rank factors is a sum
            # of 'rank' terms, each a product of one entry from each factor,
            # so the factors' std is chosen such that the product's
            # elementwise variance comes out to 1 / fan_in.
            variance = 1.0 / fan_in
            stddev = math.sqrt(math.sqrt(variance / rank))
            return partial(nn.init.trunc_normal_, std=stddev)

        q_init = make_low_rank_factorization_initializer(self.hidden_size, self.rank)
        r_init = make_low_rank_factorization_initializer(self.input_size, self.rank)
        q_init(self.Ql)
        q_init(self.Qr)
        r_init(self.Rl)
        r_init(self.Rr)

    def mogrify(self, xt, ht):
        # |xt| : (batch_size, input_size)
        # |ht| : (batch_size, hidden_size)

        # Alternate updates, starting with x on round 1. The 2*sigmoid(...)
        # scaling centers each gate around 1, so mogrification starts out
        # close to the identity.
        for i in range(1, self.mog_iterations+1):
            if i % 2 == 0:
                if self.rank == 0:
                    ht = (2*torch.sigmoid(self.R(xt))) * ht
                else:
                    _x = xt@self.Rl
                    # |_x| : (batch_size, rank)
                    _x = _x@self.Rr + self.Rb
                    # |_x| : (batch_size, hidden_size)
                    ht = 2*torch.sigmoid(_x) * ht

            else:
                if self.rank == 0:
                    xt = (2*torch.sigmoid(self.Q(ht))) * xt
                else:
                    _h = ht@self.Ql
                    # |_h| : (batch_size, rank)
                    _h = _h@self.Qr + self.Qb
                    # |_h| : (batch_size, input_size)
                    xt = 2*torch.sigmoid(_h) * xt

        return xt, ht
    
    def forward(self, input: Tensor, hidden=None) -> tuple[Tensor, Tensor]:
        if hidden is None:
            h_prev = input.new_zeros(input.size(0), self.hidden_size)
            c_prev = input.new_zeros(input.size(0), self.hidden_size)
        else:
            h_prev, c_prev = hidden

        # Mogrify x and h_prev before the usual RLSTM gating.
        input, h_prev = self.mogrify(input, h_prev)
        ij_gates = self.linear1(torch.cat([input, h_prev], dim=-1))
        i, j = torch.chunk(ij_gates, chunks=2, dim=1)
        i, j = torch.sigmoid(i), torch.tanh(j)

        f = self.linear2(torch.cat([i*j, h_prev], dim=-1))
        f = torch.sigmoid(f)

        c = f*c_prev + torch.min(i, 1-f)*j

        o = torch.sigmoid(self.linear3(c))
        h = o * torch.tanh(c)

        return h, c
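
A usage sketch for the full cell (all hyperparameters below are arbitrary; the rank must stay below min(input_size, hidden_size)):

cell = MogRLSTMCell(input_size=128, hidden_size=256, mog_iterations=5, rank=24)
x = torch.randn(16, 10, 128)           # (batch, seq_len, input_size)
hidden, outputs = None, []
for t in range(x.size(1)):
    h, c = cell(x[:, t], hidden)
    hidden = (h, c)
    outputs.append(h)
outputs = torch.stack(outputs, dim=1)  # (batch, seq_len, hidden_size)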

Results: