Mamba - Linear-Time Sequence Modeling with Selective State Spaces
Автор FLASH ATTN ебанутый, вы знали? Чувак в одно лицо поднял top1 решение для ускорения практически любой сетки, а теперь сделал сетку по мотивам, ключевое:
- сделали аппаратное ускорение(что еще ждать от Kernels enjoyer)
- выкинули attn
- выкинули MLP(sic!)
- понапихахи CONVов потому что они используют LOG(N) на выбор ответа(позже)
- Суется 1м токенов в контекст
- х5 быстрее учиться чем трансформер
- Линейная сложность на длинну контекста
- Обучаемый gating
Принципиально архитектура выглядит так:
Ну в целом понятно, докинули какой то SSM, а ssm выглядит в свою очередь так:
Да кто такой этот ваш SSM
Идея такая - attn заебись, но он не умеет сжимать контекст, те он по честному смотрит на весь контекст сразу - а это какой то долбоебизм, это ОЧЕНЬ сложно поскейлить на 100500 токенов в контексте, а значит BItter Lesson.
Авторы предлагают следующее решение: давайте заставим с помощью CONV выбирать с какого токена мы смотрим на последовательность(офк CONV обучаемый), а затем эту последовательность пропускать через обычную рекурентность(рил умно и просто, да?)
По сути авторы вводят обучаемый фильтр контекста - модель сама выбирает на что смотреть при генерации N токена!
Механизм SSM выглядит следующим образом:
СКОРОСТЬ
Модель ебет на длинных контекстах, и чем длинее конекст и больше батч - тем сильнее ебет, ее можно ставить в 10000rps
Evaluation + scaling laws
Все довольно круто, мамба лучше трансформера на одинаковых FLOPs
На бенчмарках все сильно лучше, разница примерно в 15-20% при одинаковом размере
ЛЛамы тут очевидно нет, лламу учить тупо дорого(2т токенов против 300b в ThePile)