Статьи
December 8, 2022

Пишем "Змейку" в 12 строк кода на PyTorch

Давайте рассмотрим, как использовать линейную алгебру и тензорные операции, чтобы создать всем известную игру в 12 строк.

И у вас сразу точно возникает несколько вопросов:

1. Насколько длинные эти 12 строк?

Не волнуйтесь, все они соответствуют стандарту PEP 8.

2. Зачем это вообще делать?

Иногда надо писать код просто ради фана. Кроме того, это отличный способ познакомиться с PyTorch и возможностями, которые предоставляют тензоры.

3. Но этом же нет никакой практической пользы?
Напротив. Методы, используемые в этой материале, на самом деле являются фундаментальными. И они лежат в основе модуля TensorSnake, который может эмулировать параллельно 100 миллионов игр "Змейка" на карте NVIDIA A6000 с задержкой 20 миллисекунд.

Сегодня мы программируем версию "Змейки", в которой она может перетекать за границу поля и выходить с другой стороны. Тем не менее, можно будет изменить 2 строки, чтобы реализовать стандартную версию.

Будем использовать PyTorch и NumPy. Можно было использовать даже какую-то одну из библиотек, но у PyTorch прекрасное Tensor API, а в NumPy есть хорошая функция под названием unravel_index, которую мы и будем использовать.

И договоримся, что в подсчёт строк не будут входить импорты и строка с определением функции ;)

Вопросы закрыли, зафиксировали договорённости, поехали!

Кодировка

Важнейшей частью этого кода является кодировка состояния змейки — формализация хранения информации об её положении.

В результате кодировки нам хотелось бы получить матрицу, которая при выводе через plt.imshow покажет состояние игры. И эту информацию должно быть легко обновлять.

Поэтому всю игру мы представим в виде матрицы целых чисел, где каждая пустая ячейка в игре будет иметь значение 0, хвост змеи будет 1, и по мере приближения хвоста к голове значение клеток будет увеличиваться на ещё одну единицу. Места с едой (целью) определяются значением -1. Итак, для змеи размера N клеток хвост будет равен 1, а голова — N.

Теперь нам нужно как-то формализовать действия. Вместо традиционной кодировки действий [вверх, вправо, вниз, влево] мы будем использовать кодировку [влево, вперёд, вправо], для определения направления движения относительно текущего направления змейки. Игроку может быть не очень привычно, но такой подход не является избыточным, т.к. в каждый момент любое действие является валидным (поскольку змейка не может двигаться назад).

Реализация

Ну и наконец, пишем код.

Все используемые функции хорошо описаны в документацией PyTorch API, подглядывайте туда, если что-то не понимаете.

Получение текущей позиции

Первое, что нужно сделать — это получить текущее и предыдущее положение головы змеюки. Мы можем сделать это с помощью topk(2), так как голова всегда является самым большим целым числом, а предыдущая её позиция — второе по величине число. Единственная проблема, с которой мы сталкиваемся, заключается в том, что метод topk делает расчёт только по одному измерению. Поэтому нам нужно сначала разгладить тензор с помощью метода flatten(), получить максимальные k элементов, а затем использовать вышеупомянутый unravel_index, чтобы преобразовать его обратно в двухмерное состояние. И нам надо полученные два индекса в тензоры, чтобы мы могли выполнять математические вычисления и с ними.

Вычисление следующей позиции

Чтобы вычислить следующую позицию, мы сделаем pos_cur - pos_prev. Эта операция вернёт вектор, указывающий на текущее направление движения змеи. Далее мы хотим повернуть его, но насколько?

Мы хотим повернуть его на 270 + 90 * action градусов. Таким образом, когда мы будем передавать 0, то мы поворачиваем налево, 1 — мы двигаемся прямо, а 2 — поворачиваем направо.

Для получения результата мы применяем матрицу вращения. Если матрица применяется к самой себе, это даёт нам матрицу, которая эквивалентна двойному применению преобразования. Следовательно, мы можем взять вектор направления и применить матрицу вращения на 90 градусов против часовой стрелки T([[0, -1], [1, 0]]), возведённую в степень 3 + action.

Наконец, мы добавляем текущую позицию к этому новому вектору направления, чтобы получить следующую позицию. Затем мы берём новое местоположение и взятие остатка от деления на размером поля, чтобы создать функциональность "перетекания" змейки за границу.

Как умереть

Ах, извечный вопрос. Но пока мы о змейке.

Поскольку теперь у нас есть следующая позиция, становится довольно просто определить, должна ли змейка умереть или нет. Нам просто нужно проверить, является ли snake[tuple(pos_next)] > 0, так как единственными клетками со значениями больше 0 являются те, в которых в данный момент находится змея.

Если змейка умирает, мы хотим вернуть счёт текущей игры. Это также довольно просто, поскольку счёт в игре равен длине змейки минус 2 (предполагая, что мы начинаем игру при длине змеи 2). Чтобы получить длину, нам просто нужно получить значение pos_cur, так как это текущая голова змеи. Это означает, что текущий счет равен snake[tuple(pos_cur)] — 2.

Как кушать

Время собраться с духом, следующие 3 строчки — самые сложные в игре.

Чтобы проверить, поймала ли змейка еду, мы сравниваемsnake[pos_next] с -1. Если они равны, то нам нужно найти все позиции на доске, которые в данный момент равны 0. Это пустые ячейки, куда мы потенциально можем положить следующую цель.

Когда у нас будут все эти позиции, нам нужно случайным образом выбрать один из этих индексов и обновить его значение до -1. Нам не нужно редактировать текущую запись -1, так как змея перезапишет её при перемещении.

Чтобы найти все места, которые в данный момент равны 0, мы просто используем snake == 0 (это возвращает логический тензор). Далее мы делаем .multinomial(1) для того, чтобы выбрать одну из позиций наугад. Функция multinominal(n) выбирает n случайных индексов из тензора с вероятностью, основанной на значении элемента.

Однако multinomial работает только с одной размерностью (как и topk), а также принимает только значения с плавающей точкой. Следовательно, нам нужно сначала использовать методы flatten() и .to(t.float). Таким образом, каждый индекс, значение которого равно 0, имеет одинаковую вероятность выбора, а каждый индекс, значение которого не равно 0, имеет нулевую вероятность выбора.

Как только это будет сделано, нам нужно снова использовать unravel, чтобы вернуть всё к двухмерному состоянию и обновить тензор snake.

Как двигаться

Чтобы переместить змею, мы уменьшаем текущую змейку и добавляем новую голову на следующую позицию.

Однако, мы хотим уменьшить змею только в том случае, если змейка не поймала цель. Если же поймала, то мы хотим увеличить её размер на 1 ячейку.

Поэтому мы добавляем блок else к ветке if на случай уменьшения. Поскольку каждая ячейка змейки пронумерована, у хвоста значение 1, мы можем вычесть 1 из значения каждой ячейки, размер которой больше 0 (так как только ячейки самой змейки вообще больше нуля). Вот мы и подрезали змейку на 1 клетку.

Теперь нам нужно добавить ей голову на новую позицию, т.е. установить значение следующей позиции, как значение предыдущей +1.

Заключение

А вот и всё. Вы написали "Змейку" в 12 строк кода.

def do(snake: t.Tensor, action: int):
    positions = snake.flatten().topk(2)[1]
    [pos_cur, pos_prev] = [T(unravel(x, snake.shape)) for x in positions]
    rotation = T([[0, -1], [1, 0]]).matrix_power(3 + action)
    pos_next = (pos_cur + (pos_cur - pos_prev) @ rotation) % T(snake.shape)
    
    if (snake[tuple(pos_next)] > 0).any():
        return (snake[tuple(pos_cur)] - 2).item() 
    
    if snake[tuple(pos_next)] == -1:
        pos_food = (snake == 0).flatten().to(t.float).multinomial(1)[0]
        snake[unravel(pos_food, snake.shape)] = -1
    else:
        snake[snake > 0] -= 1  
        
    snake[tuple(pos_next)] = snake[tuple(pos_cur)] + 1

Интерфейс

А, дак вы и поиграть в неё еще хотите?

Создание простенького графического интерфейса будет стоить нам ещё 15 строк, держите:

snake = t.zeros((32, 32), dtype=t.int)
snake[0, :3] = T([1, 2, -1]) 

fig, ax = plt.subplots(1, 1)
img = ax.imshow(snake)
action = {'val': 1}
action_dict = {'a': 0, 'd': 2}

fig.canvas.mpl_connect('key_press_event',
                       lambda e: action.__setitem__('val', action_dict[e.key]))

score = None
while score is None: 
    img.set_data(snake)
    fig.canvas.draw_idle()
    plt.pause(0.1) 
    score = do(snake, action['val']) 
    action['val'] = 1 
    
print('Score:', score)

Теперь можете играть сколько душе угодно :)

PythonTalk в Telegram

Чат PythonTalk в Telegram

Предложить материал | Поддержать канал

Источник: Elias F. Fyksen