December 11, 2023

Singleton is all you need (no)

На днях перед мной встала одна, на первый взгляд, простая задача для intern/junior разработчика на Python: написать singleton.

Сразу предотвращу горение тех, кто считает, что Singleton — это плохо. В общем случае да, и тебе нужно стараться их избегать. Но однако бывают случае, когда это уместно, даже несмотря на SOLID.


Итак, задача простая, давай писать. Заведем простой класс — пустышку:

class StubClass:
	pass

Наследование

Самая базовая идея, которая может придти в голову, реализовать singleton в Python через наследование.

Сделаем класс, который переопределяет магический метод __new__:

class Singleton:
    __instance = None

    def __new__(cls, *args, **kwargs):
        if cls.__instance is None:
            cls.__instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
        return cls.__instance

Тогда путем несложных программистких преобразований может получится следующий код:

class StubClass(Singleton):
	pass

Этот код ужасен. И не из-за каких-то свойств, которыми я наделил объект. Просто идейно. Ну т.е. давай здесь два класса с более человеческими названиями:

class Logger(Singleton): # какой-то общий логгер
	pass

class DBConnection(Singleton): # какое-то общее соединение
	pass

Концептуально, получается, что по иерархии наследования, это один и тот же класс. Но это же классы из разных вселенных. Если бы я ревьювил этот код, моя реакция была бы такой:

Ну сколько можно?

Метаклассы — I

Другим, более подходящим инструментом, для создания Singleton может послужить механизм метаклассов. Т.е. я наделяю свойствами не объекты какого-то типа, а сам тип.

Давай напишем такой мета-singleton:

class SingletonI(type):
    __instance = None

    def __call__(cls, *args, **kwargs):
        if cls.__instance is None:
            cls.__instance = super(SingletonI, cls).__call__(*args, **kwargs)
        return cls.__instance

Применяя к StubClass все свои 6 лет высшего образования получаю следующий код:

class StubClass(metaclass=SingletonI):
	pass

На этом можно было бы уже остановиться, но пока у меня есть несколько проблем. Дело в том, что метаклассы — такая сущность, которую лучше не менять лишний раз. А развивая систему нам может потребоваться многопоточная среда. Мой метакласс пока к этому не готов.

Однажды ко мне подошла мама и сказала, что вот сын её подруги такой Singleton запилил, что мне и не снилось. Пришлось исправляться.

Метаклассы — II

Самая первая идея, которая приходит, чтобы сделать класс thread-safe — наложить на него mutex. Тогда получится следующее:

import threading 

class SingletonII(type):
    __shared_instance = None
    __shared_instance_lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        with cls.__shared_instance_lock:
            if cls.__shared_instance is None:
                cls.__shared_instance = super(SingletonII, cls).__call__(*args, **kwargs)
        return cls.__shared_instance

И в целом это решение уже неплохое.

Если бы мой стажер пилил какую-то MVP-шечку, то моя реакция уже была бы примерно такой.

Есть одна проблема. Зачастую Singleton не хранят в каких-то объектах, а если нужно, то делают примерно так:

StubClass().make_some_action()

т.е. как бы создают новый объект, но т.к. это Singleton, то возвращается тот самый объект. Так вот, в случае глобального lock-объекта это будет дичайшим образом тормозить (с другой стороны, обучение нейронки будет работать дольше, а значит больше времени сделать кофе — решать тебе).

Метаклассы — III

Базовая идея ускорить прошлый Singleton — давай будем проверять, что
instance не является None, а если является, то там уже будем накладывать mutex, т.е. буквально поменяем строчки местами:

class SingletonIII(type):
    __shared_instance = None
    __shared_instance_lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        if cls.__shared_instance is None:
            with cls.__shared_instance_lock:
                cls.__shared_instance = super(SingletonIII, cls).__call__(*args, **kwargs)
        return cls.__shared_instance

Давай запустим примитивный бенчмарк (код для TimeitResult взят любезно отсюда):

import timeit
import sys
import math

setup = """import threading

class SingletonII(type):
    __shared_instance = None
    _shared_instance_lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        with cls._shared_instance_lock:
            if cls.__shared_instance is None:
                cls.__shared_instance = super(SingletonII, cls).__call__(*args, **kwargs)
        return cls.__shared_instance

class SingletonIII(type):
    __shared_instance = None
    __shared_instance_lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        if cls.__shared_instance is None:
            with cls.__shared_instance_lock:
                cls.__shared_instance = super(SingletonIII, cls).__call__(*args, **kwargs)
        return cls.__shared_instance

class StubII(metaclass=SingletonII):
    pass

class StubIII(metaclass=SingletonIII):
    pass

"""

def _format_time(timespan, precision=3):
    """Formats the timespan in a human readable form"""

    if timespan >= 60.0:
        # we have more than a minute, format that in a human readable form
        # Idea from http://snipplr.com/view/5713/
        parts = [("d", 60*60*24),("h", 60*60),("min", 60), ("s", 1)]
        time = []
        leftover = timespan
        for suffix, length in parts:
            value = int(leftover / length)
            if value > 0:
                leftover = leftover % length
                time.append(u'%s%s' % (str(value), suffix))
            if leftover < 1:
                break
        return " ".join(time)


    # Unfortunately the unicode 'micro' symbol can cause problems in
    # certain terminals.
    # See bug: https://bugs.launchpad.net/ipython/+bug/348466
    # Try to prevent crashes by being more secure than it needs to
    # E.g. eclipse is able to print a µ, but has no sys.stdout.encoding set.
    units = [u"s", u"ms",u'us',"ns"] # the save value
    if hasattr(sys.stdout, 'encoding') and sys.stdout.encoding:
        try:
            u'\xb5'.encode(sys.stdout.encoding)
            units = [u"s", u"ms",u'\xb5s',"ns"]
        except:
            pass
    scaling = [1, 1e3, 1e6, 1e9]

    if timespan > 0.0:
        order = min(-int(math.floor(math.log10(timespan)) // 3), 3)
    else:
        order = 3
    return "%.*g %s" % (precision, timespan * scaling[order], units[order])

class TimeitResult(object):
    """
    Object returned by the timeit magic with info about the run.

    Contains the following attributes :

    loops: (int) number of loops done per measurement
    repeat: (int) number of times the measurement has been repeated
    best: (float) best execution time / number
    all_runs: (list of float) execution time of each run (in s)
    compile_time: (float) time of statement compilation (s)

    """
    def __init__(self, loops, repeat, best, worst, all_runs, precision):
        self.loops = loops
        self.repeat = repeat
        self.best = best
        self.worst = worst
        self.all_runs = all_runs
        self._precision = precision
        self.timings = [ dt / self.loops for dt in all_runs]

    @property
    def average(self):
        return math.fsum(self.timings) / len(self.timings)

    @property
    def stdev(self):
        mean = self.average
        return (math.fsum([(x - mean) ** 2 for x in self.timings]) / len(self.timings)) ** 0.5

    def __str__(self):
        pm = '+-'
        if hasattr(sys.stdout, 'encoding') and sys.stdout.encoding:
            try:
                u'\xb1'.encode(sys.stdout.encoding)
                pm = u'\xb1'
            except:
                pass
        return "{mean} {pm} {std} per loop (mean {pm} std. dev. of {runs} run{run_plural}, {loops:,} loop{loop_plural} each)".format(
            pm=pm,
            runs=self.repeat,
            loops=self.loops,
            loop_plural="" if self.loops == 1 else "s",
            run_plural="" if self.repeat == 1 else "s",
            mean=_format_time(self.average, self._precision),
            std=_format_time(self.stdev, self._precision),
        )

    def _repr_pretty_(self, p , cycle):
        unic = self.__str__()
        p.text(u'<TimeitResult : '+unic+u'>')

def run_timeit(code, setup, repeat, number):
    res = timeit.repeat(code, setup=setup, repeat=repeat, number=number)
    print(TimeitResult(
        loops=number, repeat=repeat, best=max(res), worst=min(res), all_runs=res, precision=3
    ))

run_timeit("StubII()", setup=setup, repeat=10, number=1000000)
run_timeit("StubIII()", setup=setup, repeat=10, number=1000000)

Ожидаемый выход ускорения почти в 2.5 раза:

# run_timeit("StubII()", setup=setup, repeat=10, number=1000000)
335 ns ± 6.17 ns per loop (mean ± std. dev. of 10 runs, 1,000,000 loops each)
# run_timeit("StubIII()", setup=setup, repeat=10, number=1000000)
135 ns ± 4.01 ns per loop (mean ± std. dev. of 10 runs, 1,000,000 loops each)

Но есть проблемка. В моём коде образовался баг, который ведет к тому, что Singleton теперь может быть не single.

Это известная проблема имеет широкое освещение в курсах по многопоточке — Double-checking locking. Суть в том, что все потоки, которые заблокируются мьютексом, рано или поздно создадут свой instance.

Метаклассы — IV

Давай пофиксим, добавив еще одно условие уже внутри mutex:

class SingletonIV(type):
    __shared_instance = None
    __shared_instance_lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        if cls.__shared_instance is None:
            with cls.__shared_instance_lock:
                if cls.__shared_instance is None:
                    cls.__shared_instance = super(SingletonIV, cls).__call__(*args, **kwargs)
        return cls.__shared_instance

Чуть-чуть поменяв бенчмарк, делаем замеры:

# run_timeit("StubII()", setup=setup, repeat=10, number=1000000)
343 ns ± 5.34 ns per loop (mean ± std. dev. of 10 runs, 1,000,000 loops each)
# run_timeit("StubIII()", setup=setup, repeat=10, number=1000000)
146 ns ± 3.78 ns per loop (mean ± std. dev. of 10 runs, 1,000,000 loops each)
# run_timeit("StubIV()", setup=setup, repeat=10, number=1000000)
144 ns ± 4.14 ns per loop (mean ± std. dev. of 10 runs, 1,000,000 loops each)

Скорость почти не изменилась (в ускорение в Python могу поверить, но в большинстве других языков должна была быть просадка).

Давай тогда напишем быстро бенч на то, что у нас действительно нет так называемых data races:

import threading
from multiprocessing.pool import ThreadPool

class SingletonII(type):
    __shared_instance = None
    __shared_instance_lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        with cls.__shared_instance_lock:
            if cls.__shared_instance is None:
                cls.__shared_instance = super(SingletonII, cls).__call__(*args, **kwargs)
        return cls.__shared_instance

class SingletonIII(type):
    __shared_instance = None
    __shared_instance_lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        if cls.__shared_instance is None:
            with cls.__shared_instance_lock:
                cls.__shared_instance = super(SingletonIII, cls).__call__(*args, **kwargs)
        return cls.__shared_instance

class SingletonIV(type):
    __shared_instance = None
    __shared_instance_lock = threading.Lock()

    def __call__(cls, *args, **kwargs):
        if cls.__shared_instance is None:
            with cls.__shared_instance_lock:
                if cls.__shared_instance is None:
                    cls.__shared_instance = super(SingletonIV, cls).__call__(*args, **kwargs)
        return cls.__shared_instance


class ThreadObject(object):
    def __init__(self):
        super(ThreadObject, self).__init__()
        print("Created in thread: {}".format(threading.get_ident()))

class ThreadObjectSingletonII(ThreadObject, metaclass=SingletonII):
    pass

class ThreadObjectSingletonIII(ThreadObject, metaclass=SingletonIII):
    pass

class ThreadObjectSingletonIV(ThreadObject, metaclass=SingletonIV):
    pass


def run_test(thread_object_class):
    print(f"Run test for class: {thread_object_class.__name__}")
    def worker(_object_set, _object_set_access_lock):
        db = thread_object_class()
        with _object_set_access_lock:
            _object_set.add(db)

    object_set = set()
    object_set_access_lock = threading.Lock()

    thread_pool = ThreadPool(100)
    for _ in range(1000):
        thread_pool.apply_async(worker, (object_set, object_set_access_lock))
    thread_pool.close()
    thread_pool.join()

    print(f"Object count: {len(object_set)}")


if __name__ == "__main__":
    run_test(ThreadObjectSingletonII)
    run_test(ThreadObjectSingletonIII)
    run_test(ThreadObjectSingletonIV)

Если предыдущий скрипт запустить примерно 10 раз (просто рейзы не всегда выскакивают), видно, что IV и II реализации всегда имеют по одному объекту, тогда как реализация III может иметь несколько объектов.

Кстати, запустить можно быстро примерно так:

import os

for x in range(10):
    os.system("python3 path_to_data_race_check_script")

Итоги

Мораль? Да нет морали. Пишите код правильно. И да, вспомним классиков про метаклассы:

Metaclasses are deeper magic than 99% of users should ever worry about. If you wonder whether you need them, you don’t (the people who actually need them know with certainty that they need them, and don’t need an explanation about why).  — Tim Peters

P.S. Если я где-то неправ или набаговал — пиши, разберемся.