December 6, 2023

Mamba -  Linear-Time Sequence Modeling with Selective State Spaces

Автор FLASH ATTN ебанутый, вы знали? Чувак в одно лицо поднял top1 решение для ускорения практически любой сетки, а теперь сделал сетку по мотивам, ключевое:

  • сделали аппаратное ускорение(что еще ждать от Kernels enjoyer)
  • выкинули attn
  • выкинули MLP(sic!)
  • понапихахи CONVов потому что они используют LOG(N) на выбор ответа(позже)
  • Суется 1м токенов в контекст
  • х5 быстрее учиться чем трансформер
  • Линейная сложность на длинну контекста
  • Обучаемый gating

Принципиально архитектура выглядит так:

ну структура блока в цело понятная, да? H3 это старая selective space model, лень про нее писать

Ну в целом понятно, докинули какой то SSM, а ssm выглядит в свою очередь так:

Да кто такой этот ваш SSM

Идея такая - attn заебись, но он не умеет сжимать контекст, те он по честному смотрит на весь контекст сразу - а это какой то долбоебизм, это ОЧЕНЬ сложно поскейлить на 100500 токенов в контексте, а значит BItter Lesson.

Авторы предлагают следующее решение: давайте заставим с помощью CONV выбирать с какого токена мы смотрим на последовательность(офк CONV обучаемый), а затем эту последовательность пропускать через обычную рекурентность(рил умно и просто, да?)

По сути авторы вводят обучаемый фильтр контекста - модель сама выбирает на что смотреть при генерации N токена!

не иронично умное копирование(Selection+Copy)



Механизм SSM выглядит следующим образом:

СКОРОСТЬ

Модель ебет на длинных контекстах, и чем длинее конекст и больше батч - тем сильнее ебет, ее можно ставить в 10000rps

Evaluation + scaling laws

Все довольно круто, мамба лучше трансформера на одинаковых FLOPs

На бенчмарках все сильно лучше, разница примерно в 15-20% при одинаковом размере

ЛЛамы тут очевидно нет, лламу учить тупо дорого(2т токенов против 300b в ThePile)

А еще она скейлиться до 1м токенов в контексте