transformer
January 19, 2024

Сложности обучения трансформеров

Вольный пересказ разбора статьи (бтв там есть ошибки трактовок) c дополнениями от себя. Сначала про то какие есть фишки которые используются для стабильного трейна, потом уже про то зачем они и какие проблемы решают.

Introduction to the transformers

Pre/Post-LayerNorm transformer

Сама архитектура имеет две вариации, оригинальная архитектура из attention is all you need является Post-LN, все LLM на данный момент без исключений являются Pre-LN. Вкратце первый перфомит на маленьких трансформерах (6 и меньше слоёв) и расходится при большом количестве слоёв, а второй обучается стабильнее.

[картинка 1] Схематическое изображение архитектуры энкодер-декодер трансформеров (a) Post layer norm (b) Pre layer norm
class TransformerEncoderLayer(nn.Module):
    def forward(self, 
                inputs:Tensor, 
                key_padding_mask:Optional[Tensor]=None, 
                attn_mask:Optional[Tensor]=None,
                ) -> Tensor:
        # |inputs| : (batch_size, seq_len, d_model)
        # |key_padding_mask| : (batch_size, seq_len)
        # |attn_mask| : (batch_size, seq_len, seq_len)

        if self.norm == 'pre':
            # Forward inputs over first layernorm
            x = self.layernorm1(inputs)
            # Calculate self attention scores
            attn_output, _ = self.attn(
                query=x, key=x, value=x,
                key_padding_mask=key_padding_mask, 
                attn_mask=attn_mask
            )
            
            shortcut = inputs + attn_output # First schortcut
            x = shortcut + self.ffn(self.layernorm2(x)) # Second schortcut
            # |x| : (batch_size, seq_len, d_model)

        elif self.norm == 'post':
            attn_output, _ = self.attn(
                query=inputs, key=inputs, value=inputs, 
                key_padding_mask=key_padding_mask, attn_mask=attn_mask
            )
            x = self.layernorm1(inputs + attn_output) # First schortcut
            x = self.layernorm2(x + self.ffn(x)) # Second schortcut
            # |x| : (batch_size, seq_len, d_model)

        else: raise Exception
        
        return x

Residual connections

В каждом блоке трансформера используется два residual connection'а, они же shortcuts.

Layer normalization

Слой нормализации после каждого residual connection в Post-LN архитектуре и перед MHA, FFN в Pre-LN.

class LayerNorm(nn.Module):
    def __init__(self, hidden_size:int, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))
            
    def forward(self, x: Tensor) -> Tensor:
        # |x| : (batch_size, seq_len, d_model)

        # Calculate the E[x] of all elements
        mean = x.mean(dim=-1, keepdim=True)
        # |mean| : (batch_size, seq_len, 1)

        # Calculate the squared mean E[X^2] of all elements
        mean_x2 = (x ** 2).mean(dim=-1, keepdim=True)
        # Variance of all element Var[X] = E[X^2] - E[X]^2
        var = mean_x2 - mean ** 2
        # |var| : (batch_size, seq_len, 1)

        # Normalize x. Layernorm[X] = weight * (X - E[X])/(Var[X] + eps)^0.5 + bias
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        # Elementwise affine transformation
        x_norm = x_norm * self.weight + self.bias
        # |x_norm| : (batch_size, seq_len, d_model)

        return x_norm

Learning rate warm-up

Линейный прогрев LR с 0 до R первые t шагов оптимизатора,но как правило используется и learning rate decay, когда после шага t мы уменьшаем LR.

Adaptive optimizers

Трансформерам нужен adaptive оптимизатор, который накапливает квадраты градиентов, по типу Adam.

Scaled Dot Product

В трансформере мы измеряем похожесть токенов с помощью scaled dot product, это обычный dot product (скалярное произведение) деленный на корень из размерности векторов dk

Q,K - queris, keys, d_k - размерность вкторов Q,K

Training Tips

Scaled Dot Product

Для больших значений dk, размерностей векторов, результат обычного dot product сильно растёт в величине, толкая softmax в регионы где он имеет экстремально низкие градиенты. Чтобы проиллюстрировать, почему скалярные произведения становятся большими, предположим, что компоненты векторов q и k являются независимыми случайными переменными со средним значением 0 и дисперсией 1. Тогда их скалярное произведение, q · k = Σᵢᵈ qᵢkᵢ, имеет среднее значение 0 и дисперсию d. Таким образом скейля результат на √d мы получаем матрицу скоров со стандартным отклонением равным 1 (при условии что компоненты векторов q и k являются независимыми случайными переменными со средним значением 0 и дисперсией 1)

Residual connections

Считая градиенты слой за слоем градиенты могут становится очень маленькими (или очень большими):

[картинка 2] Chain rule

residual connections позволяет информации с более ранних слоёв легче пробираться до выхода, "перескакивать" через слои и таким образом борется с проблемой затухающих градиентов. Менее формально можно сказать что residual connection не трансформирует входной вектор x напрямую, а вносит модификацию во входной вектор x:

x = x + F(x)
[картинка 3] Схематическое изображение residual connection с функцией активации relu

Layer Normalization

Есть и обратная сторона медали в residual connection. В работе Zhang et al., 2019 показано, что выходная дисперсия residual connection растет экспоненциально с глубиной сети. Поэтому нам необходима нормализация для предотвращения градиентного взрыва для глубоких residual сетей. В трансформерах, для обработки последовательных данных, мы используем LayerNorm вместо BatchNorm, т.к. "batch statistics in transformers on NLP tasks have larger variations".

Learning rate warm-up

Нужен warmup из-за layernorm и Adam.

Выводы из статьи Xiong et al., 2020 про layer norm:

  1. LayerNorm скейлит градиенты
  2. В Post-LN скейлинг не зависит от количества слоёв трансформера L
  3. В Pre-LN градиенты параметров скейлятся на √L
  4. В Post-LN норма градиентов увеличиваются к слоям близким к выходу
  5. В Pre-LN норма градиентов остаётся та же, вне зависимости от слоя
[картинка 2] Норма градиентов FFN последнего слоя во время инициализации (начала обучения) По оси X количество слоёв, по оси Y норма градиентов последнего слоя
[картинка 3] Норма градиентов FFN каждого слоя во время инициализации (начала обучения) По оси X номер сля, по оси Y норма градиента слоя

Обучили Post-LN с разными параметрами:

[картинка 4] Сравнение разных warmup стратегий и оптимизаторов

Видно как сильно влияет выбор оптимизатор (об этом дальше), наличие warm-up и его продолжительность. Обучение не удастся с большим LR, хотя с довольно низким LR обучение пойдёт, но очень медленно:

[картинка 5] Модель сходится без warmup с lr=1e-4, в отличии от более высоких lr: 5e-4, 1e-3 и 1e-3 но с коротким warmup

Вообще это касается больше Post-LN чем Pre-LN. Как видно на [картинке 6] Pre-LN без warm-up сходится быстрее чем Post-LN с warm-up.

[картинка 6] Pre-LN без warm-up сходится быстрее чем Post-LN с warm-up

Выводы из статьи:

1. У Post-LN в начале тренировки огромные градиенты параметров рядом с residual connection (т.е. FFN), которые быстро затухают по мере обучения (предполагаю это из-за Attention Entropy Collapse, об этом далее).

2. Layernorm скейлит градиенты:

[картинка 7] x - вход, d - размерность эмбединга, если норма x больше корня из размерности тогда бэкпроп через layernorm уменьшает величину градиентов. С несколькими слоями это может быстро привести к затуханию градиентов или их возрастанию

Liu et al. (2019) показали почему warmup нужен для стабилизации обучения с адаптивными оптимизаторами: RMSprop, Adam. Проблема в высокой дисперсии в начале обучения adaptive learning rate'ов (Beta 2 параметр в Adam), фиксится через warmup, низкий LR, RAdam, или первые N (около 2000) шагов не обновлять веса модели и моментум, обновлять только лернинг рейты, после чего warm-up больше не нужен (не нашёл в статье про Post или Pre LN идёт речь).
Huang et al., 2020
показали что эффект комбинируется, из-за высокой дисперсии градиентов в начале получаются большие обновления весов, что увеличивает норму эмбедингов поступающих в layernorm. d модели фиксированная и √d/||x|| начинает очень быстро становится меньше 1 что приводит к затухающим градиентам.

Adaptive optimizers

SGD на практике не работает с трансформерами [картинка 4] из-за unbalanced gradients: дифференцируясь через MHA градиенты value Wv сильно больше query Wq и key Wk. Adam фиксит эту проблему, т.к. хранит квадраты градиентов и использует их как learning rate для каждого параметра. Но и вызывает проблему описанную выше
Напоминаю о том как выглядит MHA (с одной головой внимания):

class SelfAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim: Optional[int] = None, dropout: float = 0.):
        super(SelfAttention, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        self.sqrt = torch.sqrt(torch.tensor(hidden_dim, dtype=torch.float32))
        self.Wq = nn.Linear(input_dim, hidden_dim)
        self.Wk = nn.Linear(input_dim, hidden_dim)
        self.Wv = nn.Linear(input_dim, hidden_dim)

    def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None,
                attention_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """
        Args:
            x (torch.Tensor): Input sequence of shape (batch_size, seq_len, input_dim).
            key_padding_mask (torch.Tensor): Padding mask of shape (batch_size, seq_len),
                                             where 1 indicates padding positions and 0 indicates non-padding positions.
                                             Defaults to None.
            attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                                           where 1 indicates positions to attend and 0 indicates positions to ignore.
                                           Defaults to None.

        Returns:
            torch.Tensor: Output sequence of shape (batch_size, seq_len, out_dim).
            torch.Tensor: Attention weights of shape (batch_size, seq_len, seq_len).
        """
        q: Tensor = self.Wq(x)  # [batch_size, seq_len, hidden_dim]
        k: Tensor = self.Wk(x)  # [batch_size, seq_len, hidden_dim]
        v: Tensor = self.Wv(x)  # [batch_size, seq_len, hidden_dim]

        # Calculate scores: QK.T / sqrt
        scores = torch.bmm(q, k.transpose(1, 2)) / self.sqrt
        # |scores| - (batch_size, seq_len, seq_len)

        # Apply padding and attention masks
        scores = self.apply_masks(scores, key_padding_mask, attention_mask)

        # Apply softmax to scores: softmax(QK.T/sqrt)
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Multiply attention weights with value vectors to get context vectors
        context_vectors = torch.bmm(attention_weights, v) # softmax(QK.T/sqrt)V
        return context_vectors, attention_weights

    def apply_masks(self, scores: Tensor, key_padding_mask: Optional[Tensor] = None,
                    attention_mask: Optional[Tensor] = None) -> Tensor:
        ...

Amplification Effect

Liu et al. (2020) показали что небольшие изменение в параметрах ведут к большим изменениям на выходе из трансформера (из-за MHA), назвали это Amplification Effect. В Pre-LN зависимость этих флуктуаций от количества слоёв N обладает следующим свойством O(lon N), в Post-LN - O(N). Так же показали что Post-LN сильно зависим от residual branches, т.е. LayerNorm[X + MHA[X]] зависит больше от MHA[X] чем от X, в то время Как Pre-LN по ходу обучения учится зависеть от residual branches, отсюда и разница в Amplification Effect и стабильности обучения.

В статье они предлагают технику Admin что бы пофиксить Post-LN, суть метода в том что бы добавить обучаемый параметр Ψ размерностью 1 x D, D - размерность модели, таким образом первый LayerNorm выглядит следующим образом: LayerNorm[X ⊙ Ψ + MHA[X]], ⊙ - покомпонентное произведение. Для Post-LN инициализируется единицами. Для Pre-LN значением omega_value :

omega_value = (num_res_layers + 1) / math.log(num_res_layers + 1) - 1

где num_res_layers - 2 * количество слоёв трансформера.


Есть и метод ReZero, то же самое что и Admin, но вместо вектора, обучаемый скаляр a, инициализируемый нулём, архитектура выглядит так: X + aMHA[X]

Есть метод DeepNorm, который позволяет скейлить Post-LN до 1000 слоёв, для этого определённым спобом инициализируют веса и модернизируют второй LayerNorm, вводя константу a:

LayerNorm[X∗α + ffn(x)]

Инициализация:

def deepnorm_init(w):
    if w is ['ffn', 'v_proj', 'out_proj']:
        nn.init.xavier_normal_(w, gain=β)
    elif w is ['q_proj', 'k_proj']:
        nn.init.xavier_normal_(w, gain=1)

Значения a, β высчитываются следующим образом:

[картинка 8] N-количество слоёв энкодера, M-количество слоёв декодера

Attention Entropy Collapse

Интересная работа apple, показали что стабильность обучения трансформеров тесно связана с энтропией атеншена:

Attention Entropy p - attention_weights, которые после softmax

Показали, что низкая энтропия ведёт к расхождению и нестабильности тренировки. Пример тензора с высокой и низкой энтропией:

>>> tensor = torch.rand(4, 4)
>>> tensor1 = (tensor/100).softmax(dim=-1)
>>> print(tensor1)
tensor([[0.2499, 0.2502, 0.2499, 0.2500],
        [0.2513, 0.2492, 0.2489, 0.2507],
        [0.2503, 0.2507, 0.2499, 0.2492],
        [0.2510, 0.2505, 0.2488, 0.2497]])
>>> attention_entropy = -(tensor1 * torch.log(tensor1)).mean()
>>> print(attention_entropy) # High Entropy
tensor(0.3466)

>>> tensor2 = (tensor/0.01).softmax(dim=-1)
>>> print(tensor2.round(decimals=4))
tensor([[0.0000, 0.9999, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0000, 0.0000, 0.0000],
        [1.0000, 0.0000, 0.0000, 0.0000]])
>>> attention_entropy = -(tensor2 * torch.log(tensor2)).mean()
>>> print(attention_entropy) # Low Entropy, close to 0
tensor(3.5994e-05)

Высокий Learning Rate ведёт к attention entropy collapse (энтропия падает до 0):

[картинка 9] Обучение трансформера с 5.e-4 lr и 1e-3 lr и их Attention Entropy

Так же энтропию можно контролировать введя temperature параметр перед softmax:

attention_weights = F.softmax(scores/self.temperature, dim=-1)

Показали что уменьшая температуру, с 1 до 0.1, а тем самым уменьшая энтропию, на 10 и на 50 эпохе, трансформер расходится:

[картинка 10] Обычное обучение трансформера: черная линия, переключение температуры с 1 на 0.1 на 10 эпохе: синяя линия, переключение температуры с 1 на 0.1 на 50 эпохе: оранжевая линия,

Нижняя граница энтропия выражается через Tσe^(-σ), где T - длина последовательности, X - входной вектор, Wk, Wq - матрицы key и query MHA

Предлагают σReparam, та же спектральная нормализация с обучаемым коэффициентом γ

[картинка 11] σReparam, σ(W) ∈ R спектральная норма W, γ ∈ R обучаемый параметр, инициализированный единицой

σ(W) предлагают считать через power iteration:

# Parameters. W: weight matrix, shape (d, c); gamma: the learned spectral norm, shape (1,)
# Buffers. u: shape (d,), v: shape (c,), the left and right singular vectors of W
if init: # initialize u, v as random unit vectors and gamma to 1
    u = randn(d)
    u = u / u.norm(dim=0)
    v = randn(c)
    v = v / v.norm(dim=0)
    gamma = ones(1)
    
if training: # if in the training mode, perform one step of power iteration first
    with torch.no_grad():
        u = W.mv(v)
        u = u / u.norm(dim=0)
        v = W.T.mv(u)
        v = v / v.norm(dim=0)
        
sigma = einsum(’d,dc,c->’, u, W, v)
W_hat = gamma / sigma * W # the effective spectral norm of W_hat would be gamma

σReparam применяют для всех весов линейных и свёрточных слоёв, убирая при этом LayerNorm

[картинка 12] Сравнение разных подходов на задаче ASR
[картинка 13] Сравнение на задаче Language Modeling, за бейзлайн для сравнения берут https://arxiv.org/abs/1809.10853
[картинка 14] Сравнение σReparam с DeepNorm