diff --git a/preserve.py b/preserve.py index b067c64..9b76d92 100644 --- a/preserve.py +++ b/preserve.py @@ -11,164 +11,453 @@ or storage system - most of which handle all primitive types by default. """ -from enum import Enum, auto +# pylint: disable=bad-staticmethod-argument import inspect +from functools import partial +from collections import namedtuple +from types import SimpleNamespace, MappingProxyType -class RestoreMethod(Enum): - DIRECT = auto() - INIT = auto() - CLASS_METHOD = auto() +class RestoreMethod(): + """ + Contains the various supplied restore method implementations. All methods here are intended + to be passed in when creating a PreservableType. + + All methods must use the same argument set: `(cls, state)`, where `cls` is the class usually + stored as the `type` field in the Preservable, and `state` is whatever was previously + returned by `preserve()`. The value returned by the method is intended to be a new copy of the + instance originally passed to `preserve()`. + """ + @staticmethod + def default_restore(cls, state): + """ + The default restore method used by Preserve. Creates a new instance of the class (without + calling `__init__()`), and directly updates the object `__dict__` with the state passed + in. + + If the object has implemented a `__restore_init__()` method, that is then called. + """ + obj = cls.__new__(cls) + obj.__dict__.update(state) + if inspect.ismethod(getattr(obj, "__restore_init__", None)): + obj.__restore_init__() + return obj + + @staticmethod + def restore_after_init(cls, state): + """ + Does the same things as the default restore method, but calls `__init__()` on the new + object first, without any arguments. Useful when you want to retain existing init + behaviour but override some of the object attributes immediately afterward from the + preserved state. Often used along with `include_attrs` when you don't want to + implement `__restore_init__()`. + + Will not call `__restore_init__()` on the object, even if it is implemented. + """ + obj = cls() + obj.__dict__.update(state) + return obj + + @staticmethod + def setstate(cls, state): + """ + Uses the `__setstate__(self, state)` method on the class - the same protocol used for + things like Pickle. + """ + obj = cls.__new__(cls) + obj.__setstate__(state) + return obj + + @staticmethod + def pass_to_new(cls, state): + """ + Creates the new instance by directly passing `state` to the `__new__` method on `cls`. + Useful for immutable types (like tuples) that are initialised this way. + """ + obj = cls.__new__(cls, state) + return obj + + +class RestoreMethodGenerator(): + """ + Similar to RestoreMethod, but contains functions intended to be _called_ that will _return_ a + RestoreMethod. These will usually be a partial function with a given set of arguments to + allow customisation of the restore method. + """ + + @staticmethod + def restore_after_init_args(*args, **keywords): + """ + Returns a PreserveMethod similar to `restore_after_init()`, but allows you to pass in + a set of args and kwargs that will be passed on to the `__init__()` call. + """ + def restore_after_init_args(cls, state, *args, **keywords): + obj = cls(*args, **keywords) + obj.__dict__.update(state) + return obj + return partial(restore_after_init_args, *args, **keywords) + + +class PreserveMethod(): + """ + Contains the various supplied preserve method implementations. Methods here are intended + to be passed directly in when creating a PreservableType. + + All methods must only accept a single argument, `obj`: the object that needs to be preserved. + """ + + @staticmethod + def default_preserve(obj): + """ + The default preserve method used by Preserve. Simply grabs the `__dict__` attribute of the + object (returned by `vars(obj)`). Will only work if all object attributes are themselves + preservable. + """ + return vars(obj) + + @staticmethod + def getstate(obj): + """ + Uses the `__getstate__(obj)` _method_ on the class - the same protocol used for + things like Pickle. + """ + return obj.__getstate__() + + +class PreserveMethodGenerator(): + """ + Similar to PreserveMethod, but contains functions intended to be _called_ that will _return_ a + PreserveMethod. These will usually be a partial function with a given set of arguments to + allow customisation of the preserve method. + """ + + @staticmethod + def default_preserve_filter(include_attrs=None, exclude_attrs=None): + """ + Returns a PreserveMethod similar to `default_preserve()` that will filter out certain + attributes before returning them in the `state` dictionary. Both arguments should be + iterables of argument names (strings). + + Passing None to `include_attrs` will skip the inclusion filter. + """ + def default_preserve_filter(obj, include_attrs, exclude_attrs): + state = vars(obj) + attr_keys = state.keys() + if include_attrs: + include_attrs = set(include_attrs) + attr_keys = attr_keys & include_attrs + if exclude_attrs: + exclude_attrs = set(exclude_attrs) + attr_keys = attr_keys - exclude_attrs + return {key: state[key] for key in attr_keys} -class_key = "<_jam>" + return partial(default_preserve_filter, include_attrs=include_attrs, + exclude_attrs=exclude_attrs) + + @staticmethod + def coerce(new_type): + """ + Returns a PreserveMethod that simply passes the object into `new_type` to try and + create a new object of that type. Useful for simplifying objects when preserving them. + """ + def coerce(obj, new_type): + return new_type(obj) + return partial(coerce, new_type=new_type) + + +rm = RestoreMethod +rm_gen = RestoreMethodGenerator +pm = PreserveMethod +pm_gen = PreserveMethodGenerator + + +escape_key = "<_jam>" +""" +The string used as a dict key in the preserved output to indicate Preserve metadata. + +Set to <_jam>` by default. +""" + +escaped_state_key = "<_jam_state>" +""" +The string used as a dict key in the preserved output to indicate Preserve object state. + +Set to <_jam_state>` by default. +""" + +PreservableType = namedtuple('PreservableRecord', ['type', + 'name', + 'preserve_method', + 'restore_method']) + +common = SimpleNamespace() +common.tuple = PreservableType(tuple, 'tuple', pm_gen.coerce(list), rm.pass_to_new) + + +_preserve_types = {} +""" +An internal dict of PreservableTypes that can be preserved. Keys are the type. """ -Contains the reserved dict key used to indicate that the dict it is in -should be restored to a class instance, not just treated as a dict. -Set to ``<_jam>`` by default. If changed externally, must be set before -any ``@preservable`` decorators +_restore_types = {} +""" +An internal dict of PreservableTypes that can be restored. Keys are the names. """ +preservables = MappingProxyType(_restore_types) +""" +A dict of all PreservableTypes Preserve is currently able to handle, with the names as keys. +""" -preservables = {} +raw_preservables = [dict, list, str, int, float, bool, type(None)] """ -A dict of classes marked as ``@preservable``, used to restore them back -to class instances. +A list of types that Preserve will pass directly through to the preserved output. Note that both +`dict` and `list` are special cases, must always be present in this list (attempting to remove +them will break things). """ -def preservable(cls, restore_method=RestoreMethod.DIRECT, name=None): +def register(preservable_type): + """ + Registers an existing PerservableType for use with Preserve. Intended to be used with provided + PreservableTypes from somewhere like preserve.common. + + To add a custom preservable class, instead consider `add_preservable()` or + the `@preservable` class decorator. + + Args: + preservable_type: An instance of preserve.PreservableType + """ + _preserve_types[preservable_type.type] = preservable_type + _restore_types[preservable_type.name] = preservable_type + + +def add_preservable(preservable_type, name, preserve_method, restore_method): + """ + Creates a new PreservableType with the details given and registers it with Preserve. """ - Decorator to mark class as preservable and keep track of associated names - and classes. + if preservable_type in _preserve_types: + raise Exception(F"Preservable type already registered with type {preservable_type}") + + if type(name) != str: + raise Exception("Preservable name must be a string") + + if name in _restore_types: + raise Exception(F"Preservable type already registered with name '{name}'") + + register(PreservableType(preservable_type, name, preserve_method, restore_method)) + + +def preservable(cls=None, *, include_attrs=None, exclude_attrs=None, preserve_method=None, + restore_method=None, name=None): + """ + Class decorator to register a class as preservable. Creates and adds a corresponding + PreservableType to Preserve. Usable either raw (`@preservable` has sane defaults) or with + arguments (`@preservable(exclude_attrs = 'my_attr')`). + + The default preserve method attempts to directly preserve all object attributes, optionally + filtering them if `include_attrs` or `exclude_attrs` are present. If `preserve_method` is set, + `include_attrs` and `exclude_attrs` are ignored. + + The default restore method creates a new instance of the class - without calling __init__() - + and directly restores any preserved object attributes. If the restored object implements it, + `__restore_init__()` is then called. + + If any arguments to this decorator remain None, it will look for corresponding class + attributes on the class - this allows subclasses to inherit Preserve behaviour + where desirable. + Class attributes: `_preserve_include_attrs`, `_preserve_exclude_attrs`, `_preserve_method`, + `_restore_method`, `_preserve_name` + + If `__getstate__()` or `__setstate__()` are implemented in the class (the Pickle protocol), + they will override the default preserve or restore method, respectively. Args: - restore_method: One of the available preserve.RestoreMethod values. - Sets the method used for restoring this class. Defaults to - ``DIRECT``, skipping the ``__init__`` method and setting all - attributes as they were. - name: The string used to indentify this class in the preserved dict. - Must be unique among all ``@preservable`` classes. Defaults to the + include_attrs: A list of attribute names to use when preserving the object. Only used if + the preserve_method is left as default. Leaving this as none includes all available + attributes. + exclude_attrs: A list of attribute names to leave out when preserving the object. Only + used if the preserve_method is left as default. + preserve_method: A callable to be used to preserve this preservable. Some common ones and + details on their requirements are provided in `preserve.PreserveMethod`. + Leave as None for default behaviour. + restore_method: A callable to be used to restore this preservable. Some common ones and + details on their requirements are provided in `preserve.RestoreMethod` + Leave as None for default behaviour. + name: The string used to indentify this class in the preserved output. + Must be unique among all preservables. Defaults to the class name if left as None. """ + # If called with kwargs rather than as a direct decorator, we need to return + # a decorator that Python will only pass `cls` into: + if cls is None: + return partial(preservable, include_attrs=include_attrs, + exclude_attrs=exclude_attrs, preserve_method=preserve_method, + restore_method=restore_method, name=name) + + if include_attrs is None: + include_attrs = getattr(cls, '_preserve_include_attrs', None) + if exclude_attrs is None: + exclude_attrs = getattr(cls, '_preserve_exclude_attrs', None) + if preserve_method is None: + preserve_method = getattr(cls, '_preserve_method', None) + if restore_method is None: + restore_method = getattr(cls, '_restore_method', None) if name is None: - cls._preserve_name = cls.__name__ - else: - if type(name) != str: - raise Exception("Preservable name must be a string") - cls._preserve_name = name - cls._restore_method = restore_method - - if cls._preserve_name in preservables: - raise Exception("Duplicate preservable class name "+cls._preserve_name) - preservables[cls._preserve_name] = cls - - def _preserve(self): - dict_jam = _preserve_dict(vars(self)) - dict_jam[class_key] = self._preserve_name - return dict_jam + name = getattr(cls, '_preserve_name', None) + + if name is None: + name = cls.__name__ + + if preserve_method is None: + if hasattr(cls, '__getstate__'): + preserve_method = pm.getstate + else: + if include_attrs or exclude_attrs: + # String iterable gaurds + if isinstance(include_attrs, str): + include_attrs = (include_attrs,) + if isinstance(exclude_attrs, str): + exclude_attrs = (exclude_attrs,) + + preserve_method = pm_gen.default_preserve_filter(include_attrs, exclude_attrs) + else: + preserve_method = pm.default_preserve + + if restore_method is None: + if hasattr(cls, '__setstate__'): + restore_method = rm.setstate + else: + restore_method = rm.default_restore + + add_preservable(cls, name, preserve_method, restore_method) - cls.preserve = _preserve return cls def preserve(target_obj): """ - Preserve ``target_obj``, running through its contents recursively. + Preserve target_obj, running through its contents recursively. + + Pass the result back to `restore()` to get a copy of the original object and its content. Args: target_obj: The object to be preserved. This object and all its nested - contents must either be primitive types or objects of a - ``@preservable`` class. + contents must be preservable. Returns: - The preserved data structure - a nested structure containing only primitive - types. + The preserved data structure - a nested structure containing only types + from `raw_preservables` (by default dict, list, str, int, float, bool, type(None)) """ - # If it's a primitive, store it. If it's a dict or list, recursively preserve that. - # If it's an instance of another preservable class, call its .preserve() method. obj_type = type(target_obj) - if obj_type in (str, int, float, bool, type(None)): - return target_obj - elif obj_type == dict: - return _preserve_dict(target_obj) + + if obj_type == dict: + dict_jam = {} + for key, val in target_obj.items(): + if type(key) != str: + raise Exception("Non-string dictionary keys are not preservable") + if key in (escape_key, escaped_state_key): + raise Exception(F"Dict key '{key}' is reserved as an internal escape key") + dict_jam[key] = preserve(val) + return dict_jam + elif obj_type == list: list_jam = [] for val in target_obj: list_jam.append(preserve(val)) return list_jam - elif hasattr(target_obj, "_preserve_name"): - return target_obj.preserve() + + elif obj_type in raw_preservables: + return target_obj + + elif obj_type in _preserve_types: + return _preserve_preservable(target_obj, _preserve_types[obj_type]) + else: + raise Exception(F"Object {target_obj} is not preservable") + + +def _preserve_preservable(target_obj, preservable_type): + """ + Internal preserve function specifically to deal with PreservableTypes. Result will always be + at minimum a dict with an escape key to hold metadata. + """ + obj_state = preserve(preservable_type.preserve_method(target_obj)) + + # This dict is here for future use to store other metadata. + escaped_metadata = {} + + if (type(obj_state) == dict): + # Collapse state dict + obj_jam = obj_state else: - raise Exception("Object "+str(target_obj)+" is not preservable") + obj_jam = {escaped_state_key: obj_state} + if len(escaped_metadata) == 0: + # Collapse metadata dict + escaped_metadata = preservable_type.name + else: + escaped_metadata['name'] = preservable_type.name + obj_jam[escape_key] = escaped_metadata -def _preserve_dict(target_dict): - dict_jam = {} - for k, val in target_dict.items(): - if type(k) != str: - raise Exception("Non-string dictionary keys are not preservable") - if k == class_key: - raise Exception("Key "+class_key+" is reserved for internal use") - dict_jam[k] = preserve(val) - return dict_jam + return obj_jam def restore(obj_jam): """ - Restore the result of ``preserve()`` back into its original form. Will - recursively scan the data structure and restore any - ``@preservable`` classes according to their ``restore_method``. + Restore the result of `preserve()` back into its original form. Args: obj_jam: The data structure to restore, usually the result of a - ``preserve()`` call. + `preserve()` call. """ obj_type = type(obj_jam) - if obj_type in (str, int, float, bool, type(None)): - return obj_jam - elif obj_type == dict: + + if obj_type == dict: return _restore_dict(obj_jam) + elif obj_type == list: restored_list = [] for val in obj_jam: restored_list.append(restore(val)) return restored_list + + elif obj_type in raw_preservables: + return obj_jam + else: - raise Exception("Object "+str(obj_jam)+" is not restorable") + raise Exception(F"Object {str(obj_jam)} is not restorable") def _restore_dict(dict_jam): restored_dict = {} - for k, val in dict_jam.items(): - if type(k) != str: + for key, val in dict_jam.items(): + if type(key) != str: raise Exception("Non-string dictionary keys are not restorable") - if k != class_key: - restored_dict[k] = restore(val) - - # Check if this is an object that needs to be restored back to a class instance - if class_key in dict_jam: - if dict_jam[class_key] not in preservables: - raise Exception("Class "+dict_jam[class_key]+" has not been decorated as preservable") - f_class = preservables[dict_jam[class_key]] - # If DIRECT, skip __init__ and set attributes back directly - if f_class._restore_method == RestoreMethod.DIRECT: - restored_instance = f_class.__new__(f_class) - restored_instance.__dict__.update(restored_dict) - # if INIT, pass all attributes as keywords to __init__ method - elif f_class._restore_method == RestoreMethod.INIT: - restored_instance = f_class(**restored_dict) - # IF CLASS_METHOD, pass all attributes as keyword aguments to classmethod "unpack()" - elif f_class._restore_method == RestoreMethod.CLASS_METHOD: - if inspect.ismethod(getattr(f_class, "restore", None)): - restored_instance = f_class.restore(**restored_dict) - else: - raise Exception("Class "+str(f_class)+" does not have classmethod 'restore()'") + if key not in (escape_key, escaped_state_key): + restored_dict[key] = restore(val) + + if escape_key in dict_jam: + metadata = dict_jam[escape_key] + if type(metadata) == dict: + name = metadata['name'] + else: + # Metadata was collapsed + name = metadata + + if name not in _restore_types: + raise Exception(F"PreservableType with name '{name}' has not been" + " registered with Preserve") + + if escaped_state_key in dict_jam: + obj_state = restore(dict_jam[escaped_state_key]) else: - raise Exception("Class _restore_method " + - str(f_class._restore_method)+" is not supported") + # Object state was a dict and collapsed + obj_state = restored_dict + + preservable_type = _restore_types[name] + return preservable_type.restore_method(preservable_type.type, obj_state) - return restored_instance else: + # A plain dict return restored_dict diff --git a/setup.py b/setup.py index 0086545..8a00a29 100644 --- a/setup.py +++ b/setup.py @@ -1,7 +1,7 @@ from setuptools import setup setup(name='preserve', - version='0.3dev', + version='0.4dev', description='Preserve, for when pickling is just a bit too intense', url='https://git.distreon.net/novirium/python-preserve', author='novirium', diff --git a/tests/test_preserve.py b/tests/test_preserve.py index 5821dc0..6d02252 100644 --- a/tests/test_preserve.py +++ b/tests/test_preserve.py @@ -1,22 +1,8 @@ +# pylint: disable=redefined-outer-name, attribute-defined-outside-init import pytest import preserve -class PlainClass: - def __init__(self): - self.attr_int = 734 - self.attr_string = "I'm a test string" - self.attr_float = 42.085 - self.attr_dict = {"key1": 1, "key2": "val2"} - self.attr_list = ["item1", 2, 3, 4] - self.attr_bool = True - - -@preserve.preservable -class PreservableClass(PlainClass): - pass - - @pytest.fixture def primitive_data(): return {'key1': ["item1", {"d2k1": "val1", "d2k2": 2}, @@ -36,6 +22,21 @@ def test_primitive_data(primitive_data): assert (preserve.restore(preserved_data)) == primitive_data +class PlainClass: + def __init__(self): + self.attr_int = 734 + self.attr_string = "I'm a test string" + self.attr_float = 42.085 + self.attr_dict = {"key1": 1, "key2": "val2"} + self.attr_list = ["item1", 2, 3, 4] + self.attr_bool = True + + +@preserve.preservable +class PreservableClass(PlainClass): + pass + + def test_data_with_preservable(primitive_data): # Preservable class within a data structure primitive_data["preservable"] = PreservableClass() @@ -45,6 +46,100 @@ def test_data_with_preservable(primitive_data): assert obj_attrs_and_type_equal(restored_data["preservable"], primitive_data["preservable"]) +@preserve.preservable(exclude_attrs=['attr_int']) +class PreservableClass_Args1(PlainClass): + pass + + +@preserve.preservable(include_attrs=['attr_float'], name='other_name') +class PreservableClass_Args2(PlainClass): + pass + + +def test_preservable_args(): + restored = preserve.restore(preserve.preserve(PreservableClass_Args1())) + assert not hasattr(restored, 'attr_int') + + preserved = preserve.preserve(PreservableClass_Args2()) + restored = preserve.restore(preserved) + assert preserved[preserve.escape_key] == 'other_name' + assert len(vars(restored)) == 1 + assert restored.attr_float == 42.085 + + +@preserve.preservable +class PreservableClass_AttrArgs(PlainClass): + _preserve_include_attrs = ['attr_string'] + pass + + +@preserve.preservable +class PreservableClass_AttrArgs2(PreservableClass_AttrArgs): + pass + + +def test_preservable_class_attr_args(): + restored = preserve.restore(preserve.preserve(PreservableClass_AttrArgs())) + assert len(vars(restored)) == 1 + assert restored.attr_string == "I'm a test string" + + # The class attribute args should propogate to the subclass + restored = preserve.restore(preserve.preserve(PreservableClass_AttrArgs2())) + assert len(vars(restored)) == 1 + assert restored.attr_string == "I'm a test string" + assert type(restored) == PreservableClass_AttrArgs2 + + +@preserve.preservable +class PreservableClass_RestoreInit(PlainClass): + def __restore_init__(self): + self.attr_int2 = 987 + + +def test_restore_init(): + restored = preserve.restore(preserve.preserve(PreservableClass_RestoreInit())) + assert restored.attr_int2 == 987 + + +@preserve.preservable +class PreservableClass_PickleProtocol(PlainClass): + def __setstate__(self, state): + self.only_attr = state + + def __getstate__(self): + return "abcd" + + +def test_pickle_protocol(): + restored = preserve.restore(preserve.preserve(PreservableClass_PickleProtocol())) + assert len(vars(restored)) == 1 + assert restored.only_attr == "abcd" + + +def test_non_dict_preservable_state(): + preserved = preserve.preserve(PreservableClass_PickleProtocol()) + assert preserved[preserve.escaped_state_key] == "abcd" + + +def test_common_preservable_tuple(primitive_data): + primitive_data["tuple"] = (1, 2, 3) + with pytest.raises(Exception, match="is not preservable"): + preserve.restore(preserve.preserve(primitive_data)) + + preserve.register(preserve.common.tuple) + restored = preserve.restore(preserve.preserve(primitive_data)) + assert restored["tuple"] == (1, 2, 3) + + +def test_subclass_of_preservable(): + class PreservableClassSubclass(PreservableClass): + pass + + # A subclass _shouldn't_ be preservable by default, as if it was restore() would + # return an instance of the parent class. + with pytest.raises(Exception, match="is not preservable"): + preserve.restore(preserve.preserve(PreservableClassSubclass())) + def test_subclass_of_raw_preservable(): class RawPreservableSubclass(dict): @@ -78,10 +173,10 @@ def test_attr_plain_class(): preserve.preserve(obj) -def test_class_key(): +def test_escape_key(): # Should be able to change the default class key before decorators # and have preserve/restore work - old_class_key = preserve.class_key + old_escape_key = preserve.escape_key preserve.class_key = "A different key" @preserve.preservable @@ -93,13 +188,15 @@ def test_class_key(): assert restored_obj != obj assert obj_attrs_and_type_equal(restored_obj, obj) - preserve.class_key = old_class_key + preserve.escape_key = old_escape_key -def test_class_key_in_data(): +def test_escape_key_in_data(): # Can't use the class key as dict key being preserved - with pytest.raises(Exception, match="reserved for internal use"): - preserve.preserve({preserve.class_key: 1}) + with pytest.raises(Exception, match="reserved as an internal escape key"): + preserve.preserve({preserve.escape_key: 1}) + with pytest.raises(Exception, match="reserved as an internal escape key"): + preserve.preserve({preserve.escaped_state_key: 1}) def test_unrestorable(primitive_data):