Сложности обучения трансформеров
Вольный пересказ разбора статьи (бтв там есть ошибки трактовок) c дополнениями от себя. Сначала про то какие есть фишки которые используются для стабильного трейна, потом уже про то зачем они и какие проблемы решают.
Introduction to the transformers
Pre/Post-LayerNorm transformer
Сама архитектура имеет две вариации, оригинальная архитектура из attention is all you need является Post-LN, все LLM на данный момент без исключений являются Pre-LN. Вкратце первый перфомит на маленьких трансформерах (6 и меньше слоёв) и расходится при большом количестве слоёв, а второй обучается стабильнее.
class TransformerEncoderLayer(nn.Module):
def forward(self,
inputs:Tensor,
key_padding_mask:Optional[Tensor]=None,
attn_mask:Optional[Tensor]=None,
) -> Tensor:
# |inputs| : (batch_size, seq_len, d_model)
# |key_padding_mask| : (batch_size, seq_len)
# |attn_mask| : (batch_size, seq_len, seq_len)
if self.norm == 'pre':
# Forward inputs over first layernorm
x = self.layernorm1(inputs)
# Calculate self attention scores
attn_output, _ = self.attn(
query=x, key=x, value=x,
key_padding_mask=key_padding_mask,
attn_mask=attn_mask
)
shortcut = inputs + attn_output # First schortcut
x = shortcut + self.ffn(self.layernorm2(x)) # Second schortcut
# |x| : (batch_size, seq_len, d_model)
elif self.norm == 'post':
attn_output, _ = self.attn(
query=inputs, key=inputs, value=inputs,
key_padding_mask=key_padding_mask, attn_mask=attn_mask
)
x = self.layernorm1(inputs + attn_output) # First schortcut
x = self.layernorm2(x + self.ffn(x)) # Second schortcut
# |x| : (batch_size, seq_len, d_model)
else: raise Exception
return xResidual connections
В каждом блоке трансформера используется два residual connection'а, они же shortcuts.
Layer normalization
Слой нормализации после каждого residual connection в Post-LN архитектуре и перед MHA, FFN в Pre-LN.
class LayerNorm(nn.Module):
def __init__(self, hidden_size:int, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
def forward(self, x: Tensor) -> Tensor:
# |x| : (batch_size, seq_len, d_model)
# Calculate the E[x] of all elements
mean = x.mean(dim=-1, keepdim=True)
# |mean| : (batch_size, seq_len, 1)
# Calculate the squared mean E[X^2] of all elements
mean_x2 = (x ** 2).mean(dim=-1, keepdim=True)
# Variance of all element Var[X] = E[X^2] - E[X]^2
var = mean_x2 - mean ** 2
# |var| : (batch_size, seq_len, 1)
# Normalize x. Layernorm[X] = weight * (X - E[X])/(Var[X] + eps)^0.5 + bias
x_norm = (x - mean) / torch.sqrt(var + self.eps)
# Elementwise affine transformation
x_norm = x_norm * self.weight + self.bias
# |x_norm| : (batch_size, seq_len, d_model)
return x_normLearning rate warm-up
Линейный прогрев LR с 0 до R первые t шагов оптимизатора,но как правило используется и learning rate decay, когда после шага t мы уменьшаем LR.
Adaptive optimizers
Трансформерам нужен adaptive оптимизатор, который накапливает квадраты градиентов, по типу Adam.
Scaled Dot Product
В трансформере мы измеряем похожесть токенов с помощью scaled dot product, это обычный dot product (скалярное произведение) деленный на корень из размерности векторов dk
Training Tips
Scaled Dot Product
Для больших значений dk, размерностей векторов, результат обычного dot product сильно растёт в величине, толкая softmax в регионы где он имеет экстремально низкие градиенты. Чтобы проиллюстрировать, почему скалярные произведения становятся большими, предположим, что компоненты векторов q и k являются независимыми случайными переменными со средним значением 0 и дисперсией 1. Тогда их скалярное произведение, q · k = Σᵢᵈ qᵢkᵢ, имеет среднее значение 0 и дисперсию d. Таким образом скейля результат на √d мы получаем матрицу скоров со стандартным отклонением равным 1 (при условии что компоненты векторов q и k являются независимыми случайными переменными со средним значением 0 и дисперсией 1)
Residual connections
Считая градиенты слой за слоем градиенты могут становится очень маленькими (или очень большими):
residual connections позволяет информации с более ранних слоёв легче пробираться до выхода, "перескакивать" через слои и таким образом борется с проблемой затухающих градиентов. Менее формально можно сказать что residual connection не трансформирует входной вектор x напрямую, а вносит модификацию во входной вектор x:
x = x + F(x)
Layer Normalization
Есть и обратная сторона медали в residual connection. В работе Zhang et al., 2019 показано, что выходная дисперсия residual connection растет экспоненциально с глубиной сети. Поэтому нам необходима нормализация для предотвращения градиентного взрыва для глубоких residual сетей. В трансформерах, для обработки последовательных данных, мы используем LayerNorm вместо BatchNorm, т.к. "batch statistics in transformers on NLP tasks have larger variations".
Learning rate warm-up
Нужен warmup из-за layernorm и Adam.
Выводы из статьи Xiong et al., 2020 про layer norm:
- LayerNorm скейлит градиенты
- В Post-LN скейлинг не зависит от количества слоёв трансформера L
- В Pre-LN градиенты параметров скейлятся на √L
- В Post-LN норма градиентов увеличиваются к слоям близким к выходу
- В Pre-LN норма градиентов остаётся та же, вне зависимости от слоя
Обучили Post-LN с разными параметрами:
Видно как сильно влияет выбор оптимизатор (об этом дальше), наличие warm-up и его продолжительность. Обучение не удастся с большим LR, хотя с довольно низким LR обучение пойдёт, но очень медленно:
Вообще это касается больше Post-LN чем Pre-LN. Как видно на [картинке 6] Pre-LN без warm-up сходится быстрее чем Post-LN с warm-up.
1. У Post-LN в начале тренировки огромные градиенты параметров рядом с residual connection (т.е. FFN), которые быстро затухают по мере обучения (предполагаю это из-за Attention Entropy Collapse, об этом далее).
2. Layernorm скейлит градиенты:
Liu et al. (2019) показали почему warmup нужен для стабилизации обучения с адаптивными оптимизаторами: RMSprop, Adam. Проблема в высокой дисперсии в начале обучения adaptive learning rate'ов (Beta 2 параметр в Adam), фиксится через warmup, низкий LR, RAdam, или первые N (около 2000) шагов не обновлять веса модели и моментум, обновлять только лернинг рейты, после чего warm-up больше не нужен (не нашёл в статье про Post или Pre LN идёт речь).
Huang et al., 2020 показали что эффект комбинируется, из-за высокой дисперсии градиентов в начале получаются большие обновления весов, что увеличивает норму эмбедингов поступающих в layernorm. d модели фиксированная и √d/||x|| начинает очень быстро становится меньше 1 что приводит к затухающим градиентам.
Adaptive optimizers
SGD на практике не работает с трансформерами [картинка 4] из-за unbalanced gradients: дифференцируясь через MHA градиенты value Wv сильно больше query Wq и key Wk. Adam фиксит эту проблему, т.к. хранит квадраты градиентов и использует их как learning rate для каждого параметра. Но и вызывает проблему описанную выше
Напоминаю о том как выглядит MHA (с одной головой внимания):
class SelfAttention(nn.Module):
def __init__(self, input_dim, hidden_dim: Optional[int] = None, dropout: float = 0.):
super(SelfAttention, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.dropout = nn.Dropout(dropout)
self.sqrt = torch.sqrt(torch.tensor(hidden_dim, dtype=torch.float32))
self.Wq = nn.Linear(input_dim, hidden_dim)
self.Wk = nn.Linear(input_dim, hidden_dim)
self.Wv = nn.Linear(input_dim, hidden_dim)
def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""
Args:
x (torch.Tensor): Input sequence of shape (batch_size, seq_len, input_dim).
key_padding_mask (torch.Tensor): Padding mask of shape (batch_size, seq_len),
where 1 indicates padding positions and 0 indicates non-padding positions.
Defaults to None.
attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
where 1 indicates positions to attend and 0 indicates positions to ignore.
Defaults to None.
Returns:
torch.Tensor: Output sequence of shape (batch_size, seq_len, out_dim).
torch.Tensor: Attention weights of shape (batch_size, seq_len, seq_len).
"""
q: Tensor = self.Wq(x) # [batch_size, seq_len, hidden_dim]
k: Tensor = self.Wk(x) # [batch_size, seq_len, hidden_dim]
v: Tensor = self.Wv(x) # [batch_size, seq_len, hidden_dim]
# Calculate scores: QK.T / sqrt
scores = torch.bmm(q, k.transpose(1, 2)) / self.sqrt
# |scores| - (batch_size, seq_len, seq_len)
# Apply padding and attention masks
scores = self.apply_masks(scores, key_padding_mask, attention_mask)
# Apply softmax to scores: softmax(QK.T/sqrt)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
# Multiply attention weights with value vectors to get context vectors
context_vectors = torch.bmm(attention_weights, v) # softmax(QK.T/sqrt)V
return context_vectors, attention_weights
def apply_masks(self, scores: Tensor, key_padding_mask: Optional[Tensor] = None,
attention_mask: Optional[Tensor] = None) -> Tensor:
...Amplification Effect
Liu et al. (2020) показали что небольшие изменение в параметрах ведут к большим изменениям на выходе из трансформера (из-за MHA), назвали это Amplification Effect. В Pre-LN зависимость этих флуктуаций от количества слоёв N обладает следующим свойством O(lon N), в Post-LN - O(N). Так же показали что Post-LN сильно зависим от residual branches, т.е. LayerNorm[X + MHA[X]] зависит больше от MHA[X] чем от X, в то время Как Pre-LN по ходу обучения учится зависеть от residual branches, отсюда и разница в Amplification Effect и стабильности обучения.
В статье они предлагают технику Admin что бы пофиксить Post-LN, суть метода в том что бы добавить обучаемый параметр Ψ размерностью 1 x D, D - размерность модели, таким образом первый LayerNorm выглядит следующим образом: LayerNorm[X ⊙ Ψ + MHA[X]], ⊙ - покомпонентное произведение. Для Post-LN инициализируется единицами. Для Pre-LN значением omega_value :
omega_value = (num_res_layers + 1) / math.log(num_res_layers + 1) - 1
где num_res_layers - 2 * количество слоёв трансформера.
Есть и метод ReZero, то же самое что и Admin, но вместо вектора, обучаемый скаляр a, инициализируемый нулём, архитектура выглядит так: X + aMHA[X]
Есть метод DeepNorm, который позволяет скейлить Post-LN до 1000 слоёв, для этого определённым спобом инициализируют веса и модернизируют второй LayerNorm, вводя константу a:
def deepnorm_init(w):
if w is ['ffn', 'v_proj', 'out_proj']:
nn.init.xavier_normal_(w, gain=β)
elif w is ['q_proj', 'k_proj']:
nn.init.xavier_normal_(w, gain=1)Значения a, β высчитываются следующим образом:
Attention Entropy Collapse
Интересная работа apple, показали что стабильность обучения трансформеров тесно связана с энтропией атеншена:
Показали, что низкая энтропия ведёт к расхождению и нестабильности тренировки. Пример тензора с высокой и низкой энтропией:
>>> tensor = torch.rand(4, 4)
>>> tensor1 = (tensor/100).softmax(dim=-1)
>>> print(tensor1)
tensor([[0.2499, 0.2502, 0.2499, 0.2500],
[0.2513, 0.2492, 0.2489, 0.2507],
[0.2503, 0.2507, 0.2499, 0.2492],
[0.2510, 0.2505, 0.2488, 0.2497]])
>>> attention_entropy = -(tensor1 * torch.log(tensor1)).mean()
>>> print(attention_entropy) # High Entropy
tensor(0.3466)
>>> tensor2 = (tensor/0.01).softmax(dim=-1)
>>> print(tensor2.round(decimals=4))
tensor([[0.0000, 0.9999, 0.0000, 0.0000],
[1.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000, 0.0000],
[1.0000, 0.0000, 0.0000, 0.0000]])
>>> attention_entropy = -(tensor2 * torch.log(tensor2)).mean()
>>> print(attention_entropy) # Low Entropy, close to 0
tensor(3.5994e-05)Высокий Learning Rate ведёт к attention entropy collapse (энтропия падает до 0):
Так же энтропию можно контролировать введя temperature параметр перед softmax:
attention_weights = F.softmax(scores/self.temperature, dim=-1)
Показали что уменьшая температуру, с 1 до 0.1, а тем самым уменьшая энтропию, на 10 и на 50 эпохе, трансформер расходится:
Нижняя граница энтропия выражается через Tσe^(-σ), где T - длина последовательности, X - входной вектор, Wk, Wq - матрицы key и query MHA
Предлагают σReparam, та же спектральная нормализация с обучаемым коэффициентом γ
σ(W) предлагают считать через power iteration:
# Parameters. W: weight matrix, shape (d, c); gamma: the learned spectral norm, shape (1,)
# Buffers. u: shape (d,), v: shape (c,), the left and right singular vectors of W
if init: # initialize u, v as random unit vectors and gamma to 1
u = randn(d)
u = u / u.norm(dim=0)
v = randn(c)
v = v / v.norm(dim=0)
gamma = ones(1)
if training: # if in the training mode, perform one step of power iteration first
with torch.no_grad():
u = W.mv(v)
u = u / u.norm(dim=0)
v = W.T.mv(u)
v = v / v.norm(dim=0)
sigma = einsum(’d,dc,c->’, u, W, v)
W_hat = gamma / sigma * W # the effective spectral norm of W_hat would be gammaσReparam применяют для всех весов линейных и свёрточных слоёв, убирая при этом LayerNorm