<?xml version="1.0" encoding="utf-8" ?><rss version="2.0" xmlns:tt="http://teletype.in/" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:content="http://purl.org/rss/1.0/modules/content/" xmlns:media="http://search.yahoo.com/mrss/"><channel><title>Victor</title><generator>teletype.in</generator><description><![CDATA[Victor]]></description><link>https://teletype.in/@avenida?utm_source=teletype&amp;utm_medium=feed_rss&amp;utm_campaign=avenida</link><atom:link rel="self" type="application/rss+xml" href="https://teletype.in/rss/avenida?offset=0"></atom:link><atom:link rel="next" type="application/rss+xml" href="https://teletype.in/rss/avenida?offset=10"></atom:link><atom:link rel="search" type="application/opensearchdescription+xml" title="Teletype" href="https://teletype.in/opensearch.xml"></atom:link><pubDate>Thu, 16 Apr 2026 11:46:35 GMT</pubDate><lastBuildDate>Thu, 16 Apr 2026 11:46:35 GMT</lastBuildDate><item><guid isPermaLink="true">https://teletype.in/@avenida/bGx_rXp-zSv</guid><link>https://teletype.in/@avenida/bGx_rXp-zSv?utm_source=teletype&amp;utm_medium=feed_rss&amp;utm_campaign=avenida</link><comments>https://teletype.in/@avenida/bGx_rXp-zSv?utm_source=teletype&amp;utm_medium=feed_rss&amp;utm_campaign=avenida#comments</comments><dc:creator>avenida</dc:creator><title>16-, 8- и 4-битные форматы чисел с плавающей запятой</title><pubDate>Fri, 08 Dec 2023 08:40:25 GMT</pubDate><media:content medium="image" url="https://img4.teletype.in/files/f5/8a/f58ac37f-025a-4336-93e0-5ab3d0f7965b.png"></media:content><description><![CDATA[<img src="https://habrastorage.org/r/w1560/getpro/habr/upload_files/c3f/bad/c83/c3fbadc838abfa6ca95c601ca774109d.png"></img>https://habr.com/ru/companies/wunderfund/articles/776496/]]></description><content:encoded><![CDATA[
  <p id="KrX8"><a href="https://habr.com/ru/companies/wunderfund/articles/776496/" target="_blank">https://habr.com/ru/companies/wunderfund/articles/776496/</a></p>
  <p id="cT7b">Средний</p>
  <p id="XBGx">15 мин</p>
  <p id="nzAq">16K <a href="https://habr.com/ru/companies/wunderfund/articles/" target="_blank">Блог компании Wunder Fund</a><a href="https://habr.com/ru/hubs/webdev/" target="_blank">Веб-разработка*</a><a href="https://habr.com/ru/hubs/python/" target="_blank">Python*</a><a href="https://habr.com/ru/hubs/programming/" target="_blank">Программирование*</a></p>
  <p id="H0Mo">Уже лет 50, со времён выхода первого издания «Языка программирования Си» Кернигана и Ритчи, известно, что «числа с плавающей запятой» одинарной точности имеют размер 32 бита, а числа двойной точности — 64 бита. Существуют ещё и 80-битные числа расширенной точности типа «long double». Эти типы данных покрывали почти все нужды обработки вещественных чисел. Но в последние несколько лет, с наступлением эпохи больших нейросетевых моделей, у разработчиков появилась потребность в типах данных, которые не «больше», а «меньше» существующих, потребность в том, чтобы как можно сильнее «сжать» типы данных, представляющие числа с плавающей запятой.</p>
  <figure id="khPs" class="m_custom">
    <img src="https://habrastorage.org/r/w1560/getpro/habr/upload_files/c3f/bad/c83/c3fbadc838abfa6ca95c601ca774109d.png" width="780" />
  </figure>
  <p id="EpzC">Я, честно говоря, был удивлён, когда узнал о существовании 4-битного формата для представления чисел с плавающей запятой. Да как такое вообще возможно? Лучший способ узнать об этом — самостоятельно поработать с такими числами. Сейчас мы исследуем самые популярные форматы чисел с плавающей запятой, создадим с использованием некоторых из них простую нейронную сеть и понаблюдаем за тем, как она работает.</p>
  <h2 id="7Zwc">«Стандартные» 32-битные числа с плавающей запятой</h2>
  <p id="KJcv">Прежде чем переходить к описанию «экстремальных» типов данных — давайте вспомним о стандартном типе. Стандарт <a href="https://en.wikipedia.org/wiki/IEEE_754" target="_blank">IEEE 754</a>, регламентирующий арифметику с плавающей запятой, был принят в 1985 году Институтом инженеров электротехники и электроники (Institute of Electrical and Electronics Engineers, IEEE). Типичное 32-битное число с плавающей запятой, в соответствии с этим стандартном, выглядит так:</p>
  <figure id="ygse" class="m_custom">
    <img src="https://habrastorage.org/r/w1560/getpro/habr/upload_files/6ab/49e/1a6/6ab49e1a68d008f53d49cec5e3d7ee06.png" width="700" />
    <figcaption>Пример 32-битного числа с плавающей запятой (<a href="https://en.wikipedia.org/wiki/IEEE_754" target="_blank">источник</a>)</figcaption>
  </figure>
  <p id="Fxvm">Первый бит задаёт знак числа, следующие 8 битов представляют порядок, а остальные биты — мантиссу. Десятичное значение числа находят по следующей формуле:</p>
  <figure id="tKqR" class="m_custom">
    <img src="https://habrastorage.org/r/w1560/getpro/habr/upload_files/699/26b/ccc/69926bcccaf2ada34c1e2830eee78b1a.png" width="370" />
    <figcaption>Формула для нахождения десятичного значения двоичного числа с плавающей запятой (<a href="https://en.wikipedia.org/wiki/Floating-point_arithmetic" target="_blank">источник</a>)</figcaption>
  </figure>
  <p id="sCNn">Вот — простая вспомогательная функция, которая позволит нам выводить на экран числа с плавающей запятой в их двоичном виде:</p>
  <pre id="EvLD">import struct

def print_float32(val: float):
    &quot;&quot;&quot; Print Float32 in a binary form &quot;&quot;&quot;
    m = struct.unpack(&#x27;I&#x27;, struct.pack(&#x27;f&#x27;, val))[0]
    return format(m, &#x27;b&#x27;).zfill(32)

print_float32(0.15625)

# &gt; 00111110001000000000000000000000</pre>
  <p id="Btkz">Напишем ещё одну вспомогательную функцию, которая позволяет выполнять обратное преобразование. Позже она нам пригодится:</p>
  <pre id="QQao">def ieee_754_conversion(sign, exponent_raw, mantissa, exp_len=8, mant_len=23):
    &quot;&quot;&quot; Convert binary data into the floating point value &quot;&quot;&quot;
    sign_mult = -1 if sign == 1 else 1
    exponent = exponent_raw - (2 ** (exp_len - 1) - 1)
    mant_mult = 1
    for b in range(mant_len - 1, -1, -1):
        if mantissa &amp; (2 ** b):
            mant_mult += 1 / (2 ** (mant_len - b))

    return sign_mult * (2 ** exponent) * mant_mult


ieee_754_conversion(0b0, 0b01111100, 0b01000000000000000000000)

#&gt; 0.15625</pre>
  <p id="T4RO">И я надеюсь, что все программисты и IT‑энтузиасты знают, что точность чисел с плавающей запятой ограничена:</p>
  <pre id="nKjK">val = 3.14
print(f&quot;{val:.20f}&quot;)

# &gt; 3.14000000000000012434</pre>
  <p id="MyrH">Это, в данном случае, не такая уж и проблема. Но, чем меньше у нас бит, тем меньше точность, на которую можно рассчитывать. И, как мы скоро увидим, точность вполне может быть проблемой. А теперь — начнём путешествие по кроличьей норе…</p>
  <h2 id="UMVs">16-битные числа с плавающей запятой</h2>
  <p id="YbDI">Очевидно, раньше особой потребности в 16-битных числах с плавающей запятой не было, поэтому описание соответствующего типа было добавлено в стандарт IEEE 754 только в 2008 году. У таких чисел имеется знаковый бит, 5-битный порядок и 10-битная мантисса:</p>
  <figure id="1gBQ" class="m_custom">
    <img src="https://habrastorage.org/r/w1560/getpro/habr/upload_files/d2d/679/d61/d2d679d618d217e44262a8b6dd50cdc1.png" width="609" />
    <figcaption>Пример 16-битного числа с плавающей запятой (<a href="https://en.wikipedia.org/wiki/Bfloat16_floating-point_format" target="_blank">источник</a>)</figcaption>
  </figure>
  <p id="PEkt">Логика преобразования десятичных представлений таких чисел в двоичные точно такая же, как и при работе с 32-битными числами, но их точность, безусловно, ниже, чем у 32-битных чисел. Выведем 16-битное число с плавающей запятой в двоичном виде:</p>
  <pre id="gZmZ">import numpy as np

def print_float16(val: float):
    &quot;&quot;&quot; Print Float16 in a binary form &quot;&quot;&quot;
    m = struct.unpack(&#x27;H&#x27;, struct.pack(&#x27;e&#x27;, np.float16(val)))[0]
    return format(m, &#x27;b&#x27;).zfill(16)

print_float16(3.14)

# &gt; 0100001001001000</pre>
  <p id="JzvE">Прибегнув к методу, которым мы уже пользовались, можем выполнить обратное преобразование:</p>
  <pre id="cERq">ieee_754_conversion(0, 0b10000, 0b1001001000, exp_len=5, mant_len=10)

# &gt; 3.140625</pre>
  <p id="KuTs">А вот как можно найти максимальное значение, представимое в виде числа типа <code>float16</code>:</p>
  <pre id="d0tG">ieee_754_conversion(0, 0b11110, 0b1111111111, exp_len=5, mant_len=10)

#&gt; 65504.0</pre>
  <p id="J4zc">Я использовал тут <code>0b11110</code> из-за того, что в стандарте IEEE 754 число <code>0b11111</code> зарезервировано для «бесконечности». Можно найти и возможное минимальное значение:</p>
  <pre id="ibGr">ieee_754_conversion(0, 0b00001, 0b0000000000, exp_len=5, mant_len=10)

#&gt; 0.00006104</pre>
  <p id="qyKr">Для большинства разработчиков типы, вроде описанного — это «неизведанная территория». И, судя по всему, даже в наши дни в C++ нет стандартного 16-битного типа данных для чисел с плавающей запятой. Но разнообразие типов этим не ограничивается.</p>
  <h2 id="bjuT">16-битные числа с плавающей запятой «bfloat» (BFP16)</h2>
  <p id="avS3">Этот формат чисел с плавающей запятой разработан командой Google Brain. Он спроектирован специально для нужд машинного обучения (буква «B» в его названии — это сокращение от «brain»). Это — модификация «стандартного» 16-битного формата: порядок увеличен до 8 бит, в результате диапазон значений <code>bfloat16</code>, на самом деле, получается таким же, как у <code>float32</code>. Но размер мантиссы был уменьшен до 7 бит:</p>
  <figure id="DRBI" class="m_custom">
    <img src="https://habrastorage.org/r/w1560/getpro/habr/upload_files/dfb/7f4/b53/dfb7f4b531f5bf6f02aaf8299026c618.png" width="609" />
    <figcaption>Пример 16-битного числа с плавающей запятой bfloat16 (<a href="https://en.wikipedia.org/wiki/Bfloat16_floating-point_format" target="_blank">источник</a>)</figcaption>
  </figure>
  <p id="IS7i">Проведём небольшой эксперимент, аналогичный предыдущим:</p>
  <pre id="wYj0">ieee_754_conversion(0, 0b10000000, 0b1001001, exp_len=8, mant_len=7)

#&gt; 3.140625</pre>
  <p id="NUtB">Как уже было сказано — из‑за увеличенного порядка формат <code>bfloat16</code> вмещает в себя гораздо больший диапазон значений, чем <code>float16</code>:</p>
  <pre id="bJ86">ieee_754_conversion(0, 0b11111110, 0b1111111, exp_len=8, mant_len=7)

#&gt; 3.3895313892515355e+38</pre>
  <p id="4tgD">Это — гораздо лучше в сравнении с <code>65504.0</code> из предыдущего примера, но, как уже было сказано, точность чисел <code>bfloat16</code> ниже из‑за того, что на мантиссу приходится меньшее число бит. Можно протестировать оба типа в TensorFlow:</p>
  <pre id="YS7Z">import tensorflow as tf

print(f&quot;{tf.constant(1.2, dtype=tf.float16).numpy().item():.12f}&quot;)

# &gt; 1.200195312500

print(f&quot;{tf.constant(1.2, dtype=tf.bfloat16).numpy().item():.12f}&quot;)

# &gt; 1.203125000000</pre>
  <h2 id="eNoW">8-битные числа с плавающей запятой (FP8)</h2>
  <p id="l0gD">Этот (сравнительно новый) формат был предложен в 2022 году и, как может догадаться читатель, он тоже создан для целей машинного обучения. Модели становятся всё больше и больше, их всё сложнее и сложнее умещать в памяти GPU. Формат FP8 существует в двух вариантах: E4M3 (4-битный порядок и 3-битная мантисса) и E5M2 (5-битный порядок и 2-битная мантисса):</p>
  <figure id="ORgd" class="m_custom">
    <img src="https://habrastorage.org/r/w1560/getpro/habr/upload_files/0b1/dbb/d22/0b1dbbd22150843b54a5b1ec4562f9e2.png" width="700" />
    <figcaption>Пример 8-битных чисел с плавающей запятой (<a href="https://en.wikipedia.org/wiki/Minifloat" target="_blank">источник</a>)</figcaption>
  </figure>
  <p id="I3tU">Выясним максимально возможные значения чисел для обоих вариантов FP8:</p>
  <pre id="NyiC">ieee_754_conversion(0, 0b1111, 0b110, exp_len=4, mant_len=3)

# &gt; 448.0

ieee_754_conversion(0, 0b11110, 0b11, exp_len=5, mant_len=2)

# &gt; 57344.0</pre>
  <p id="SkI5">Формат FP8 можно использовать и в TensorFlow:</p>
  <pre id="AclB">import tensorflow as tf
from tensorflow.python.framework import dtypes


a_fp8 = tf.constant(3.14, dtype=dtypes.float8_e4m3fn)
print(a_fp8)

# &gt; 3.25

a_fp8 = tf.constant(3.14, dtype=dtypes.float8_e5m2)
print(a_fp8)

# &gt; 3.0</pre>
  <p id="HfBj">Нарисуем график синуса, используя оба типа:</p>
  <pre id="v1ma">import numpy as np
import tensorflow as tf
from tensorflow.python.framework import dtypes
import matplotlib.pyplot as plt

length = np.pi * 4
resolution = 200
xvals = np.arange(0, length, length / resolution)
wave = np.sin(xvals)
wave_fp8_1 = tf.cast(wave, dtypes.float8_e4m3fn)
wave_fp8_2 = tf.cast(wave, dtypes.float8_e5m2)

plt.rcParams[&quot;figure.figsize&quot;] = (14, 5)
plt.plot(xvals, wave_fp8_1.numpy())
plt.plot(xvals, wave_fp8_2.numpy())
plt.show()</pre>
  <p id="jvL0">Результат, что удивительно, не так уж и плох:</p>
  <figure id="Ke8K" class="m_custom">
    <img src="https://habrastorage.org/r/w1560/getpro/habr/upload_files/2ce/3eb/306/2ce3eb3061c1fd7b6091a5a94a671963.png" width="700" />
    <figcaption>Синусоидальная волна, построенная по данным, представленным в разных вариантах формата FP8 (изображение подготовлено автором)</figcaption>
  </figure>
  <p id="a2uA">Тут ясно видны некоторые потери точности, но то, что получилось, очень даже похоже на синусоиду!</p>
  <h2 id="IfM9">4-битные числа с плавающей запятой (FP4, NF4)</h2>
  <p id="0MXv">А теперь перейдём к самой «безумной» теме — к 4-битным числам с плавающей запятой (FP4). На самом деле такие числа — это самые компактные значения с плавающей запятой, соответствующие стандарту IEEE, имеющие 1 бит на знак, 2 бита на порядок и 1 бит на мантиссу:</p>
  <figure id="S46A" class="m_custom">
    <img src="https://habrastorage.org/r/w1560/getpro/habr/upload_files/bf3/d1f/be5/bf3d1fbe54f3e486040f46e75c87e666.png" width="700" />
    <figcaption>Пример значения FP4 (изображение подготовлено автором)</figcaption>
  </figure>
  <p id="CjyI">Количество значений, которые можно сохранить в формате FP4, невелико. Все эти значения, на самом деле, помещаются в массив на 16 элементов!</p>
  <p id="DdVn">Ещё одна возможная реализация 4-битных чисел с плавающей запятой представлена типом данных, называемым NormalFloat (NF4). Значения NF4 оптимизированы для сохранения нормально распределённых данных. Все возможные значения NF4 легко вывести на экран в виде небольшого списка (при исследовании других типов данных это может оказаться совсем непростой задачей):</p>
  <pre id="SesK">[-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, 
 -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
  0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 
  0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0]</pre>
  <p id="IjYN">И тип FP4, и тип NF4 реализованы в Python‑библиотеке <a href="https://github.com/TimDettmers/bitsandbytes" target="_blank">bitsandbytes</a>. Давайте, в качестве примера, преобразуем массив <code>[1.0, 2.0, 3.0, 4.0]</code> в формат FP4:</p>
  <pre id="HjHY">from bitsandbytes import functional as bf

def print_uint(val: int, n_digits=8) -&gt; str:
    &quot;&quot;&quot; Convert 42 =&gt; �&#x27; &quot;&quot;&quot;
    return format(val, &#x27;b&#x27;).zfill(n_digits)

device = torch.device(&quot;cuda&quot;)
x = torch.tensor([1.0, 2.0, 3.0, 4.0], device=device)
x_4bit, qstate = bf.quantize_fp4(x, blocksize=64)

print(x_4bit)
# &gt; tensor([[117], [35]], dtype=torch.uint8)

print_uint(x_4bit[0].item())
# &gt; 01110101
print_uint(x_4bit[1].item())
# &gt; 00100011

print(qstate)
# &gt; (tensor([4.]), 
# &gt;  &#x27;fp4&#x27;, 
# &gt;  tensor([ 0.0000,  0.0052,  0.6667,  1.0000,  0.3333,  0.5000,  0.1667,  0.2500,
# &gt;           0.0000, -0.0052, -0.6667, -1.0000, -0.3333, -0.5000, -0.1667, -0.2500])])</pre>
  <p id="spN5">Результат выглядит интересно. На выходе получилось два объекта: 16-битный массив <code>[117, 35]</code>, содержащий наши 4 числа, и объект «состояния», в котором находятся коэффициент масштабирования 4.0 и тензор со всеми шестнадцатью FP4-числами.</p>
  <p id="gist">Например, первое 4-битное число — это «0111» (=7). В объекте состояния можно видеть, что соответствующее ему значение с плавающей запятой — это 0.25; 0.25*4 = 1.0. Второе число — это «0101» (=5), а результирующее значение — 0.5*4 = 2.0. Третье число — это «0010», которое равняется 2, а соответствующее ему значение — 0.666*4 = 2.666, которое достаточно близко к 3, но не равно этому числу. Понятно, что при применении 4-битных значений мы столкнёмся с некоторой потерей точности. Последнее значение, «0011» — это 3, ему соответствует 1.000*4 = 4.0.</p>
  <p id="gHYe">Понятно, что нет большой необходимости выполнять подобные вычисления вручную. С помощью <code>bitsandbytes</code> можно выполнить и обратное преобразование:</p>
  <pre id="03xy">x = bf.dequantize_fp4(x_4bit, qstate)
print(x)

# &gt; tensor([1.000, 2.000, 2.666, 4.000])</pre>
  <p id="o8jp">4-битный формат чисел тоже обладает ограниченным диапазоном значений. Например, массив <code>[1.0, 2.0, 3.0, 64.0]</code> будет преобразован в <code>[0.333, 0.333, 0.333, 64.0]</code>. Но для более или менее нормализованных данных он даёт совсем неплохие результаты. Давайте, для примера, нарисуем синусоиду, воспользовавшись данными в формате FP4:</p>
  <pre id="CMEJ">import matplotlib.pyplot as plt
import numpy as np
from bitsandbytes import functional as bf

length = np.pi * 4
resolution = 256
xvals = np.arange(0, length, length / resolution)
wave = np.sin(xvals)

x_4bit, qstate = bf.quantize_fp4(torch.tensor(wave, dtype=torch.float32, device=device), blocksize=64)
dq = bf.dequantize_fp4(x_4bit, qstate)

plt.rcParams[&quot;figure.figsize&quot;] = (14, 5)
plt.title(&#x27;FP8 Sine Wave&#x27;)
plt.plot(xvals, wave)
plt.plot(xvals, dq.cpu().numpy())
plt.show()</pre>
  <p id="xD7m">Тут, что неудивительно, видны некоторые потери точности, но то, что получилось, выглядит довольно прилично.</p>
  <figure id="Llyp" class="m_custom">
    <img src="https://habrastorage.org/r/w1560/getpro/habr/upload_files/8a6/1a5/42c/8a61a542c3ed08a4a7ed1e9d1ccf8617.png" width="700" />
    <figcaption>Синусоидальная волна, построенная по данным, представленным в формате FP4 (изображение подготовлено автором)</figcaption>
  </figure>
  <p id="jwHw">Если же говорить о типе NF4 — читатели сами могут попробовать исследовать его с помощью методов <code>quantize_nf4</code> и <code>dequantize_nf4</code>; весь код останется таким же, как прежде. Но, к сожалению, на момент написания этой статьи 4-битные типы данных работают лишь с CUDA; вычисления на CPU пока не поддерживаются.</p>
  <h2 id="E2b2">Тестирование</h2>
  <p id="55C0">Теперь, в роли финального этапа этой статьи, предлагаю создать нейросетевую модель и протестировать её. При использовании Python‑библиотеки <a href="https://huggingface.co/docs/transformers/main_classes/quantization" target="_blank">transformers</a> можно загрузить заранее обученную модель в 4-битном формате. Для этого достаточно установить в <code>True</code> параметр <code>load_in_4-bit</code>. Но будем честны: это не приблизит нас к пониманию того, как новые форматы чисел влияют на нейросетевые модели. Вместо этого прибегнем к «игрушечному» примеру — создадим маленькую нейросеть, обучим её и воспользуемся ей, применив 4-битные числа.</p>
  <p id="ThlR">Для начала создадим нейросетевую модель:</p>
  <pre id="grKW">import torch
import torch.nn as nn
import torch.optim as optim
from typing import Any

class NetNormal(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.model = nn.Sequential(
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )
      
    def forward(self, x):
        x = self.flatten(x)
        x = self.model(x)
        return F.log_softmax(x, dim=1)</pre>
  <p id="U2yr">Теперь надо подготовить загрузчик набора данных. Я буду использовать набор данных <a href="https://pytorch.org/vision/0.15/generated/torchvision.datasets.MNIST.html" target="_blank">MNIST</a>, содержащий 70000 изображений рукописных цифр размером 28x28 (авторские права на этот набор данных принадлежат Яну Лекуну и Коринне Кортез, он доступен по лицензии <a href="https://creativecommons.org/licenses/by-sa/3.0/" target="_blank">Creative Commons Attribution-Share Alike 3.0</a>). Набор данных разделён на две части — 60000 учебных и 10000 тестовых изображений. Выбор загружаемых данных может быть выполнен в загрузчике путём использования параметра <code>train=True|False</code>.</p>
  <pre id="U1Nr">from torchvision import datasets, transforms

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(&quot;data&quot;, train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(&quot;data&quot;, train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)</pre>
  <p id="aSzj">Теперь мы готовы к тому, чтобы обучить и сохранить модель. Процесс обучения выполняется «нормальным» способом, с применением стандартного формата чисел.</p>
  <pre id="WIgw">device = torch.device(&quot;cuda&quot;)

batch_size = 64
epochs = 4
log_interval = 500

def train(model: nn.Module, train_loader: torch.utils.data.DataLoader,
          optimizer: Any, epoch: int):
    &quot;&quot;&quot; Train the model &quot;&quot;&quot;
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        
        if batch_idx % log_interval == 0:
            print(f&#x27;Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}]\tLoss: {loss.item():.5f}&#x27;)
            
def test(model: nn.Module, test_loader: torch.utils.data.DataLoader):
    &quot;&quot;&quot; Test the model &quot;&quot;&quot;
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            t_start = time.monotonic()
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction=&#x27;sum&#x27;).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    t_diff = time.monotonic() - t_start

    print(f&quot;Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset)}%)\n&quot;)

def get_size_kb(model: nn.Module):
    &quot;&quot;&quot; Get model size in kilobytes &quot;&quot;&quot;
    size_model = 0
    for param in model.parameters():
        if param.data.is_floating_point():
            size_model += param.numel() * torch.finfo(param.data.dtype).bits
        else:
            size_model += param.numel() * torch.iinfo(param.data.dtype).bits
    print(f&quot;Model size: {size_model / (8*1024)} KB&quot;)

# Обучение
model = NetNormal().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(1, epochs + 1):
    train(model, train_loader, optimizer, epoch)
    test(model, test_loader)

get_size(model)

# Сохранение
torch.save(model.state_dict(), &quot;mnist_model.pt&quot;)</pre>
  <p id="zewV">Я, кроме того, написал вспомогательный метод <code>get_size_kb</code>, позволяющий узнать размер модели в килобайтах.</p>
  <p id="aDlX">Вот как выглядит процесс обучения модели:</p>
  <pre id="nsDA">Train Epoch: 1 [0/60000] Loss: 2.31558
Train Epoch: 1 [32000/60000] Loss: 0.53704
Test set: Average loss: 0.2684, Accuracy: 9225/10000 (92.25%)

Train Epoch: 2 [0/60000] Loss: 0.19791
Train Epoch: 2 [32000/60000] Loss: 0.17268
Test set: Average loss: 0.1998, Accuracy: 9401/10000 (94.01%)

Train Epoch: 3 [0/60000] Loss: 0.30570
Train Epoch: 3 [32000/60000] Loss: 0.33042
Test set: Average loss: 0.1614, Accuracy: 9530/10000 (95.3%)

Train Epoch: 4 [0/60000] Loss: 0.20046
Train Epoch: 4 [32000/60000] Loss: 0.19178
Test set: Average loss: 0.1376, Accuracy: 9601/10000 (96.01%)

Model size: 427.2890625 KB</pre>
  <p id="K91u">Наша простая модель достигла точности в 96%, размер нейронной сети — 427 Кб.</p>
  <p id="HvCY">А теперь — самое интересное! Создадим и протестируем 8-битную версию модели. Описание модели будет, на самом деле, таким же, как прежде. Я лишь заменил слой <code>Linear</code> на слой <code>Linear8bitLt</code>.</p>
  <pre id="QpHD">from bitsandbytes.nn import Linear8bitLt

class Net8Bit(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.model = nn.Sequential(
            Linear8bitLt(784, 128, has_fp16_weights=False),
            nn.ReLU(),
            Linear8bitLt(128, 64, has_fp16_weights=False),
            nn.ReLU(),
            Linear8bitLt(64, 10, has_fp16_weights=False)
        )
      
    def forward(self, x):
        x = self.flatten(x)
        x = self.model(x)
        return F.log_softmax(x, dim=1)

device = torch.device(&quot;cuda&quot;)

# Загрузка
model = Net8Bit()
model.load_state_dict(torch.load(&quot;mnist_model.pt&quot;))
get_size_kb(model)
print(model.model[0].weight)

# Преобразование
model = model.to(device)

get_size_kb(model)
print(model.model[0].weight)

# Запуск
test(model, test_loader)</pre>
  <p id="PUXB">Вот — выходные данные:</p>
  <pre id="vFd2">Model size: 427.2890625 KB
Parameter(Int8Params([[ 0.0071,  0.0059,  0.0146,  ...,  0.0111, -0.0041,  0.0025],
            ...,
            [-0.0131, -0.0093, -0.0016,  ..., -0.0156,  0.0042,  0.0296]]))

Model size: 107.4140625 KB
Parameter(Int8Params([[  9,   7,  19,  ...,  14,  -5,   3],
            ...,
            [-21, -15,  -3,  ..., -25,   7,  47]], device=&#x27;cuda:0&#x27;,
           dtype=torch.int8))

Test set: Average loss: 0.1347, Accuracy: 9600/10000 (96.0%)</pre>
  <p id="sVj6">Исходная модель была загружена с использованием стандартного формата чисел с плавающей запятой. Её размер остался таким же, веса выглядят как <code>[0.0071, 0.0059,…]</code>. Вся «магия» заключается в преобразовании модели в <code>cuda</code> — она становится в 4 раза меньше. Как видно, значения весов находятся в одном и том же диапазоне, поэтому преобразование модели сложностей не вызывает. В процессе проверки модели на тестовых данных оказалось, что она не потеряла ни единого процента точности!</p>
  <p id="YU8E">А теперь — 4-битная версия:</p>
  <pre id="qO3R">from bitsandbytes.nn import LinearFP4, LinearNF4

class Net4Bit(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.model = nn.Sequential(
            LinearFP4(784, 128),
            nn.ReLU(),
            LinearFP4(128, 64),
            nn.ReLU(),
            LinearFP4(64, 10)
        )
      
    def forward(self, x):
        x = self.flatten(x)
        x = self.model(x)
        return F.log_softmax(x, dim=1)

# Загрузка
model = Net4Bit()
model.load_state_dict(torch.load(&quot;mnist_model.pt&quot;))
get_model_size(model)
print(model.model[2].weight)

# Преобразование
model = model.to(device)

get_model_size(model)
print(model.model[2].weight)

# Запуск
test(model, test_loader)</pre>
  <p id="73JA">Вот — результаты работы:</p>
  <pre id="GheQ">Model size: 427.2890625 KB
Parameter(Params4bit([[ 0.0916, -0.0453,  0.0891,  ...,  0.0430, -0.1094, -0.0751],
            ...,
            [-0.0079, -0.1021, -0.0094,  ..., -0.0124,  0.0889,  0.0048]]))

Model size: 54.1015625 KB
Parameter(Params4bit([[ 95], [ 81], [109],
            ...,
            [ 34], [ 46], [ 33]], device=&#x27;cuda:0&#x27;, dtype=torch.uint8))

Test set: Average loss: 0.1414, Accuracy: 9579/10000 (95.79%)</pre>
  <p id="szlh">Мы получили интересные результаты. После преобразования размер модели уменьшился в 8 раз — с 427 до 54 Кб, но точность упала лишь на 1%. Как это возможно? Ответить на этот вопрос несложно. По крайней мере — для этой модели:</p>
  <ul id="wX7p">
    <li id="QAf9">Как видно, веса распределены более или менее равномерно, и потеря точности не слишком велика.</li>
    <li id="Bw0x">При обработке выходных данных в модели используется <a href="https://pytorch.org/docs/stable/generated/torch.nn.functional.log_softmax.html" target="_blank">Softmax</a>, результат определяется по индексу максимального значения. Несложно понять, что при поиске максимального индекса само значение роли не играет. Например — между 0,8 и 0,9 нет никакой разницы в том случае, если другие значения — это 0,1 или 0,2.</li>
  </ul>
  <p id="LJwq">Полагаю — важно более тщательно изучить то, что у нас получилось. Загрузим числа из тестового набора данных и ознакомимся с тем, что выдаст модель.</p>
  <pre id="Wag3">dataset = datasets.MNIST(&#x27;data&#x27;, train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))

np.set_printoptions(precision=3, suppress=True)  # Не использовать научную запись

data_in = dataset[4][0]
for x in range(28):
    for y in range(28):
        print(f&quot;{data_in[0][x][y]: .1f}&quot;, end=&quot; &quot;)
    print()</pre>
  <p id="DDvf">Вот — выведенное на экран число, которое нужно распознать:</p>
  <figure id="pxWy" class="m_custom">
    <img src="https://habrastorage.org/r/w1560/getpro/habr/upload_files/193/1b8/325/1931b83257f4954f94d3e46c403e5bb9.png" width="700" />
    <figcaption>Вывод данных</figcaption>
  </figure>
  <p id="4LGr">Посмотрим, что выдаст «стандартная» модель:</p>
  <pre id="IU1T"># Подавить научную запись
np.set_printoptions(precision=2, suppress=True)  

# Прогноз
with torch.no_grad():
    output = model(data_in.to(device))
    print(output[0].cpu().numpy())
    ind = output.argmax(dim=1, keepdim=True)[0].cpu().item()
    print(&quot;Result:&quot;, ind)

# &gt; [ -8.27 -13.89  -6.89 -11.13  -0.03  -8.09  -7.46  -7.6   -6.43  -3.77]
# &gt; Result: 4</pre>
  <p id="4ezF">Максимальный элемент находится в 5-й позиции (элементы в массивах numpy нумеруются с 0), что соответствует числу 4.</p>
  <p id="esAJ">Вот — результаты работы 8-битной модели:</p>
  <pre id="Nnfi"># &gt; [ -9.09 -12.66  -8.42 -12.2   -0.01  -9.25  -8.29  -7.26  -8.36  -4.45]
# &gt; Result: 4</pre>
  <p id="aY5z">Вот что выдала 4-битная модель:</p>
  <pre id="3wgV"># &gt; [ -8.56 -12.12  -7.52 -12.1   -0.01  -8.94  -7.84  -7.41  -7.31  -4.45]
# &gt; Result: 4</pre>
  <p id="5kTk">Хорошо видно, что реальные выходные значения у разных моделей различаются, но индекс максимального элемента остаётся одним и тем же.</p>
  <h2 id="5PWl">Итоги</h2>
  <p id="QsM6">В этой статье мы исследовали разные способы представления 16-битных, 8-битных и 4-битных чисел с плавающей запятой. Мы создали нейронную сеть и смогли запустить её с применением 8-битных и 4-битных чисел. И, на самом деле, за тем, как она работает, было интересно наблюдать. Уменьшая точность используемых чисел — со стандартной до 4-битной, нам удалось снизить объём памяти, необходимый модели, в 8 раз, при этом потеря точности оказалась минимальной. Конечно, мы экспериментировали на «игрушечном» примере, в по‑настоящему больших моделях используются более сложные механизмы (тем, кто интересуется данной темой, рекомендую <a href="https://huggingface.co/blog/hf-bitsandbytes-integration" target="_blank">этот</a> материал).</p>
  <p id="mEaO">Надеюсь, эта статья помогла вам получить представление об общих идеях, лежащих в основе вычислений с плавающей запятой. Как известно, «нужда — мать изобретений». Уменьшение объёма памяти, занимаемой моделью, в 4–8 раз — это замечательное достижение, особенно учитывая разницу в цене между видеокартами с памятью в 8, 16, 32 и 64 Гб ;).</p>
  <p id="pUcR">Кстати, даже 4 бита — это уже не предел. В публикации о <a href="https://arxiv.org/pdf/2210.17323.pdf" target="_blank">GTPQ</a> была упомянута возможность квантификации весов в 2 или даже в три (1,5 бита!) состояния. И последнее — по порядку, но не по важности: интересно поразмышлять о «точности» нейротрансмиттеров человеческого мозга. Интуитивно понятно, что она не так уж и высока. Возможно, 2- или 4-битные нейросетевые модели ближе, чем другие, к тем «моделям», которые находятся в наших головах.</p>
  <p id="i7i5">О, а приходите к нам работать? 🤗 💰</p>
  <p id="39bw">Теги:</p>
  <ul id="v15V">
    <li id="vmV7"><a href="https://habr.com/ru/search/?target_type=posts&order=relevance&q=%5BPython%5D" target="_blank">Python</a></li>
    <li id="pSi0"><a href="https://habr.com/ru/search/?target_type=posts&order=relevance&q=%5B%D1%80%D0%B0%D0%B7%D1%80%D0%B0%D0%B1%D0%BE%D1%82%D0%BA%D0%B0%5D" target="_blank">разработка</a></li>
    <li id="siTh"><a href="https://habr.com/ru/search/?target_type=posts&order=relevance&q=%5B%D1%87%D0%B8%D1%81%D0%BB%D0%B0%20%D1%81%20%D0%BF%D0%BB%D0%B0%D0%B2%D0%B0%D1%8E%D1%89%D0%B5%D0%B9%20%D0%B7%D0%B0%D0%BF%D1%8F%D1%82%D0%BE%D0%B9%5D" target="_blank">числа с плавающей запятой</a></li>
  </ul>
  <p id="nq5l">Хабы:</p>
  <ul id="8YR7">
    <li id="0Jwh"><a href="https://habr.com/ru/companies/wunderfund/articles/" target="_blank">Блог компании Wunder Fund</a></li>
    <li id="uQCs"><a href="https://habr.com/ru/hubs/webdev/" target="_blank">Веб-разработка</a></li>
    <li id="Kwyg"><a href="https://habr.com/ru/hubs/python/" target="_blank">Python</a></li>
    <li id="vVCP"><a href="https://habr.com/ru/hubs/programming/" target="_blank">Программирование</a></li>
  </ul>

]]></content:encoded></item></channel></rss>