Snap Diffusion
Text-to-Image Diffusion Model on Mobile Devices within Two Seconds
Все мы любим генерировать причудливые картинки с помощью диффузионных генеративных моделей: Midjourney, Stable Diffusion, Шедеврум, Kandinsky и многих других. Нейросети, обученные на огромных количествах пар текстов и картинок способны удовлеторить почти любую, даже самую безумную фантазию заказчика (если не считать правильного количества пальцев, взаимного расположения обьектов и иных нюансов).
Однако фундаментальной проблемой диффузионных моделей является скорость их работы. В отличие от GANов, которые берут шум и сразу выдают картинку (возможно прогнав через SuperResolution каскад), при генерации диффузионными моделями приходится много раз применять довольно тяжеловесный UNet. Потому даже на хорошей GPU процесс генерации происходит не по щелчку пальца, что уже говорить при попытке инференса на мобильных устройствах.
В своем желании ускорить диффузионки человечество исследовало различные направления:
- Продвинутые samplerы, способные генерировать картинки хорошего качества за малое число шагов
- Архитектурные оптимизации
- Пошаговая дистилляция, когда модель-ученик пытается предсказать за один раз несколько последовательных шагов учителя.
В данной работе команда из Snap и Cеверо-Восточного Университета скомбинировала несколько известных подходов для ускорения диффузионных моделей и смогла добиться генерации картинок менее чем за 2с на iPhone 14.
Анализ
Для того, чтобы понять, куда копать авторы статьи замерили время работы различных компонент. Как и следовало ожидать, основное время работы уходит на UNet, применяемый много раз. VAE Decoder, погружащий картинку в латентное пространство, в исходной пайплайне занимает пренебрежимо мало времени, но в отпимизированной версии тоже надлежит отпимизации. Текстовый энкодер работает за несколько миллисекунд и не требует никакой оптимизации.
Большинство параметров находятся в центральной части UNet, что неудивительно, так как данная архитектура уменьшает постепенно разрешение и увеличивает число каналов, однако большая часть вычислинений приходится на первые и последние блоки, причем на модули Cross-Attention между текстовыми токенами и токенами изображения.
Архитектурные оптимизации
Как утверждают авторы, прунинг отдельных компонент или их поиск (не совсем понятно, что имеется в виду - поиск в пространстве возможных операций или conditional sparsity?) приводит к серьезному ухудшению качества изображений. Потому они предлагают вариант обучения, изначально устойчивый к изменениям архитектуры - elastic / stochastic depth, где каждый блок выполняется или не выполняется с заданной вероятностью (residual структура ResNet и Cross-Attention позволяет это сделать).
Далее компоненты ранжируются по метрике полезности, определяемой как отношение изменения CLIP Score на Latency (время исполнения блока). Для каждого замера берется подвыборка из 2k картинок из MS-COCO, что занимает 2.5 часа на одной A100. Далее сортируем компоненты по данной метрике и сохраняем самые полезные, чтобы удовлетворить заданному ограничению на время T.
В итоге оказалось полезным избавляться от дорогостоящего Cross-Attention на ранних и поздних стадиях UNet (с наибольшим пространственным разрешением) и уменьшать количество ResNet блоков в глубине сети.
Оптимизация VAE-decoder менее интересна. Если я правильно понял, то они берут decoder как в исходной модели Stable Diffusion v1.5, но меньшего размера, и дистиллируют его на предсказания декодера SD v1.5. Качество при этом не проседает.
Пошаговая дистилляция
Примечателен способ дистилляции. Следуя прошлым работам, авторы делают пошаговую дистилляцию не по шуму eps
, а скорости v
.
В предложенной схеме, ученик пытается за один шаг предсказать два шага учителя.
Как оказалось, подобная стратегия позволяет сохранить качество генерации по метрике FID, но приводит к существенную ухудшению CLIP score. Потому было предложено использовать classifier-guidance aware версию скорости, что значительно улучшило CLIP score (меру соответствия картинки тексту).
Дистилляция проводится поэтапно, весьма хитрым способом:
- Stable Diffusion 32 шага -> Stable Diffusion 16 шагов
- Stable Diffusion 16 шагов -> Snap Diffusion 16 шагов
- Snap Diffusion, с шага 2, дистиллируется на Stable Diffusion 16 шагов, для получения Snap Diffusion 8 шагов
Эксперименты
Обучалась вся эта красота на 16-32 нодах с 8 A100.
Предложенная схема дистилляция показала себя лучше, чем быстрые солверы.
Итоговая сеть показывает себя не хуже SD v1.5, работая на порядок быстрее при этом. Оптимум по FID и CLIP score, однако достигается у Snap Diffusion при других значениях guidance. На графике справа ниже авторы демонстрируют, что их стратегия дистилляции лучше возможных альтернатив.
Из ablation - авторы проверяют, что сеть устойчива к вырыванию Cross-Attention и ResNet блоков в отличие от исходной сети, которая ломается.
Итого, довольно сильный и практически полезный результат.
Как мне кажется, в качестве альтернативы предложенной stochastic-depth стратегии можно было бы использовать DARTS и выбирать из нескольких вариантов модулей различной стоимости.