Source code for wrapyfi.utils.serialization_encoders

import abc
import io
import json
import base64
from datetime import datetime

import numpy as np

from wrapyfi.utils.core_utils import *


class JsonEncoder(json.JSONEncoder):
    """
    A custom JSON encoder that can encode:
        - Tuples (tagged so that they are not silently turned into lists)
        - Sets
        - Datetime objects
        - NumPy datetime64 objects
        - NumPy ndarray objects (and other NumPy scalars)
        - Objects registered with the PluginRegistrar

    Every non-native value is emitted as ``{"__wrapyfi__": [<type tag>, <payload>, ...]}``.
    """

    def __init__(self, **kwargs):
        """
        Initialize the JsonEncoder.

        :param kwargs: dict: Additional keyword arguments. The value stored under the 'serializer_kwargs' key
                             (if any) is unpacked and passed to ``json.JSONEncoder``. The complete ``kwargs``
                             dictionary is passed to the constructor of every registered encoder plugin.
        """
        super().__init__(**kwargs.get("serializer_kwargs", {}))
        # One plugin instance per registry key, built once per encoder.
        self.plugins = {
            plugin_key: plugin_cls(**kwargs)
            for plugin_key, plugin_cls in PluginRegistrar.encoder_registry.items()
        }

    @staticmethod
    def _hint_tuples(item):
        """
        Recursively replace tuples with a tagged dictionary so that they survive a JSON round trip.

        Only tuples, lists and dictionary values are traversed; every other object is returned unchanged.

        :param item: Any: The object (or nested container) to process
        :return: Any: A copy of the container structure in which every tuple is replaced with
                      ``{"__wrapyfi__": ("tuple", [...])}``
        """
        hint = JsonEncoder._hint_tuples
        if isinstance(item, tuple):
            # Recurse into the elements as well, otherwise nested tuples decode as lists.
            return dict(__wrapyfi__=("tuple", [hint(element) for element in item]))
        if isinstance(item, list):
            return [hint(element) for element in item]
        if isinstance(item, dict):
            return {key: hint(value) for key, value in item.items()}
        return item

    def find_plugin(self, obj):
        """
        Find the plugin for a given object.

        The classes in the object's MRO (excluding ``object``) are inspected from the base-most class towards
        the object's own type, and the first class with a registered plugin wins.

        :param obj: Any: The object to find the plugin for
        :return: Plugin: The plugin for the given object if one of its classes is registered, None otherwise
        """
        for cls in reversed(type(obj).__mro__[:-1]):
            if cls.__module__ == "collections.abc":
                continue  # skip classes from collections.abc
            # NOTE(review): issubclass() is only true for metaclasses derived from ABCMeta; an isinstance()
            # check may have been intended. Kept unchanged to preserve the existing lookup results.
            if issubclass(cls, abc.ABCMeta) and cls.__abstractmethods__:
                continue  # skip abstract classes with abstract methods
            plugin = self.plugins.get(cls, None)
            if plugin is not None:
                return plugin
        return None

    def encode(self, obj):
        """
        Encode an object into a JSON string and ensure that tuples are not encoded as lists.

        :param obj: Any: The object to encode
        :return: str: The JSON string representation of the object returned by the base class
        """
        return super().encode(self._hint_tuples(obj))

    def default(self, obj):
        """
        The default method for the JSON encoder. It is called by the base class for objects that it cannot
        serialize natively and converts them into a tagged dictionary.

        :param obj: Any: The object to encode
        :return: dict: A dictionary containing the type tag and the encoded data
        :raises TypeError: If the object is neither handled here nor by a registered plugin
        """
        if isinstance(obj, set):
            # Set members are not visited by encode(), so tuples inside the set must be hinted here;
            # otherwise they decode as (unhashable) lists and the set cannot be rebuilt.
            return dict(__wrapyfi__=("set", [self._hint_tuples(element) for element in obj]))
        if isinstance(obj, datetime):
            return dict(__wrapyfi__=("datetime", obj.isoformat()))
        # np.datetime64 is itself an np.generic, so it must be tested before the generic NumPy branch.
        if isinstance(obj, np.datetime64):
            return dict(__wrapyfi__=("numpy.datetime64", str(obj)))
        if isinstance(obj, (np.ndarray, np.generic)):
            # Serialize in the .npy format and base64-encode the bytes to obtain a JSON-safe string.
            with io.BytesIO() as memfile:
                np.save(memfile, obj)
                obj_data = base64.b64encode(memfile.getvalue()).decode("ascii")
            return dict(__wrapyfi__=("numpy.ndarray", obj_data))

        plugin_match = self.find_plugin(obj)
        if plugin_match is not None:
            detected, plugin_return = plugin_match.encode(obj)
            if detected:
                return plugin_return

        # Let the base class default method raise the TypeError
        return super().default(obj)
class JsonDecodeHook(object):
    """
    A custom JSON decoder hook that can decode:
        - Tuples
        - Sets
        - Datetime objects
        - NumPy datetime64 objects
        - NumPy ndarray objects
        - Objects registered with the PluginRegistrar

    Tagged values have the form ``{"__wrapyfi__": [<type tag>, <payload>, ...]}``. Dictionaries that do not
    follow this form are returned unchanged.
    """

    def __init__(self, **kwargs):
        """
        Initialize the JsonDecodeHook.

        :param kwargs: dict: Additional keyword arguments are passed to the constructor of every registered
                             decoder plugin.
        """
        # One plugin instance per registry key, built once per hook.
        self.plugins = {
            plugin_key: plugin_cls(**kwargs)
            for plugin_key, plugin_cls in PluginRegistrar.decoder_registry.items()
        }

    def object_hook(self, obj):
        """
        The object hook for the JSON decoder. This method post-processes the object after decoding it.

        :param obj: Any: The object to decode if the object is a dictionary containing the type tag and
                         encoded data
        :return: Any: The decoded object, or ``obj`` itself when it is not a recognized tagged dictionary
        """
        if not isinstance(obj, dict):
            return obj

        wrapyfi = obj.get("__wrapyfi__", None)
        # A dictionary may use the "__wrapyfi__" key without following the tagged layout
        # (missing, scalar or empty value); pass such dictionaries through instead of crashing.
        if not isinstance(wrapyfi, (list, tuple)) or not wrapyfi:
            return obj

        obj_type = wrapyfi[0]
        if len(wrapyfi) > 1:  # the built-in types all need a payload in position 1
            payload = wrapyfi[1]
            if obj_type == "tuple":
                return tuple(payload)
            if obj_type == "set":
                return set(payload)
            if obj_type == "datetime":
                return datetime.fromisoformat(payload)
            if obj_type == "numpy.datetime64":
                return np.datetime64(payload)
            if obj_type == "numpy.ndarray":
                # np.load is called with its default arguments (allow_pickle is False by default),
                # so pickled object arrays are rejected rather than unpickled.
                with io.BytesIO(base64.b64decode(payload.encode("ascii"))) as memfile:
                    return np.load(memfile)

        try:
            plugin_match = self.plugins.get(obj_type, None)
        except TypeError:
            # An unhashable tag (e.g. a list) cannot name a plugin; leave the dictionary untouched.
            return obj
        if plugin_match is not None:
            detected, plugin_return = plugin_match.decode(obj_type, wrapyfi)
            if detected:
                return plugin_return
        return obj