python
May 23, 2023

Сериализация кастомного класса

При работе с пользовательскими классами зачастую возникает необходимость их сериализации. При этом, если системе не сообщить, как она должна происходить, вы получите уведомление об ошибке (откуда Python знать, что вы хотите сохранять). Расскажу, как действовать в этом случае и создать свой быстрый способ сериализации. В качестве структуры данных для сериализации рассмотрим список пользовательских объектов namedtuple, задающих модели и колонки, к которым они применяются:

import os
import zipfile
import shutil

from collections import namedtuple
from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import LinearRegression

Model = namedtuple('complex_model', ['name', 'columns', 'model'])
model1 = Model('tree', ['col1', 'col2'], LinearRegression())
model2 = Model('linear', ['col1', 'col3'], DecisionTreeRegressor())

steps = [model1, model2]

Если просто попытаться сохранить объект steps, получите ошибку:

import joblib
joblib.dump(steps, 'steps')

Для создания своей схемы сериализации надо определить, какая информация нужна для однозначного восстановления вашей структуры и сохранить ее по частям. В данном случае нам нужно для каждого элемента списка отдельно сохранить словарь параметров и их значений, а также готовые экземпляры моделей (просто названий мало, так как они могут быть уже обучены). С учетом большого количества файлов, можно сохранить их в папку, а затем заархивировать ее. Данный архив и будет формой хранения состояния нашего объекта:

dir_n = 'agg_params'
shutil.rmtree(dir_n, ignore_errors=True)
os.mkdir(dir_n)

for i, step in enumerate(steps):
    joblib.dump({k:v for k,v in step._asdict().items() if k!='model'}, f'{dir_n}/params_{i}')
    joblib.dump(step._asdict()['model'], f'{dir_n}/model_{i}')

Для дампа "знакомых" Python объектов используется модуль joblib, а после заархивируем их с модулем zipfile:

with zipfile.ZipFile(f'{dir_n}.zip', 'w', compression=zipfile.ZIP_DEFLATED) as zip_f:
    for filename in os.listdir(dir_n):
           zip_f.write(os.path.join(dir_n, filename), arcname=filename)

shutil.rmtree(dir_n)

Восстановление осуществляется в обратном порядке. Сначала файлы извлекаются из архива в папку:

arch_fn = f'{dir_n}.zip'

zip_f = zipfile.ZipFile(arch_fn)
zip_f.extractall(arch_fn[:-4])
os.remove(arch_fn)

А затем все с тем же joblib дампы извлекаются в структуры Python:

steps = []
for i in range(len(os.listdir(arch_fn[:-4]))//2):
    params = joblib.load(os.path.join(arch_fn[:-4], f'params_{i}'))
    model = joblib.load(os.path.join(arch_fn[:-4], f'model_{i}'))
    steps.append(Model(params['name'], params['columns'], model))
shutil.rmtree(arch_fn[:-4]) 

Вывод объекта steps свидетельствует о том, что сериализация-десериализация прошли корректно:

steps