diff --git a/preserve.py b/preserve.py index 5614a1f..b067c64 100644 --- a/preserve.py +++ b/preserve.py @@ -55,7 +55,7 @@ def preservable(cls, restore_method=RestoreMethod.DIRECT, name=None): if name is None: cls._preserve_name = cls.__name__ else: - if isinstance(name, str): + if type(name) != str: raise Exception("Preservable name must be a string") cls._preserve_name = name cls._restore_method = restore_method @@ -88,11 +88,12 @@ def preserve(target_obj): """ # If it's a primitive, store it. If it's a dict or list, recursively preserve that. # If it's an instance of another preservable class, call its .preserve() method. - if isinstance(target_obj, (str, int, float, bool, type(None))): + obj_type = type(target_obj) + if obj_type in (str, int, float, bool, type(None)): return target_obj - elif isinstance(target_obj, dict): + elif obj_type == dict: return _preserve_dict(target_obj) - elif isinstance(target_obj, list): + elif obj_type == list: list_jam = [] for val in target_obj: list_jam.append(preserve(val)) @@ -106,7 +107,7 @@ def preserve(target_obj): def _preserve_dict(target_dict): dict_jam = {} for k, val in target_dict.items(): - if not isinstance(k, str): + if type(k) != str: raise Exception("Non-string dictionary keys are not preservable") if k == class_key: raise Exception("Key "+class_key+" is reserved for internal use") @@ -124,11 +125,12 @@ def restore(obj_jam): obj_jam: The data structure to restore, usually the result of a ``preserve()`` call. """ - if isinstance(obj_jam, (str, int, float, bool, type(None))): + obj_type = type(obj_jam) + if obj_type in (str, int, float, bool, type(None)): return obj_jam - elif isinstance(obj_jam, dict): + elif obj_type == dict: return _restore_dict(obj_jam) - elif isinstance(obj_jam, list): + elif obj_type == list: restored_list = [] for val in obj_jam: restored_list.append(restore(val)) @@ -140,7 +142,7 @@ def restore(obj_jam): def _restore_dict(dict_jam): restored_dict = {} for k, val in dict_jam.items(): - if not isinstance(k, str): + if type(k) != str: raise Exception("Non-string dictionary keys are not restorable") if k != class_key: restored_dict[k] = restore(val) diff --git a/tests/test_preserve.py b/tests/test_preserve.py index da39d8f..5821dc0 100644 --- a/tests/test_preserve.py +++ b/tests/test_preserve.py @@ -45,6 +45,17 @@ def test_data_with_preservable(primitive_data): assert obj_attrs_and_type_equal(restored_data["preservable"], primitive_data["preservable"]) + +def test_subclass_of_raw_preservable(): + class RawPreservableSubclass(dict): + pass + + # A subclass _shouldn't_ be preservable by default, as if it was restore() would + # return an instance of the parent class. + with pytest.raises(Exception, match="is not preservable"): + preserve.restore(preserve.preserve(RawPreservableSubclass())) + + def test_plain_class(): # plain class is not preservable with pytest.raises(Exception, match="is not preservable"):