Add support for shared objects in keras when (de)serializing.

Previously, objects shared between multiple layers would be duplicated when
saving and loading.  This commit adds a unique ID for keras objects when
serializing that allows us to correctly create only a single instance of
shared objects when deserializing.

PiperOrigin-RevId: 353123479
Change-Id: I057ba61fe587d5cb97238fb31ca562dc5791cf88
This commit is contained in:
A. Unique TensorFlower 2021-01-21 16:00:59 -08:00 committed by TensorFlower Gardener
parent a7d6c05bd6
commit 2d3d381d5f
9 changed files with 53 additions and 483 deletions

View File

@ -671,13 +671,12 @@ class Functional(training_lib.Model):
Raises:
ValueError: In case of improperly formatted config dict.
"""
with generic_utils.SharedObjectLoadingScope():
input_tensors, output_tensors, created_layers = reconstruct_from_config(
config, custom_objects)
model = cls(inputs=input_tensors, outputs=output_tensors,
name=config.get('name'))
connect_ancillary_layers(model, created_layers)
return model
input_tensors, output_tensors, created_layers = reconstruct_from_config(
config, custom_objects)
model = cls(inputs=input_tensors, outputs=output_tensors,
name=config.get('name'))
connect_ancillary_layers(model, created_layers)
return model
def _validate_graph_inputs_and_outputs(self):
"""Validates the inputs and outputs of a Graph Network."""
@ -1347,23 +1346,21 @@ def get_network_config(network, serialize_layer_fn=None):
node_conversion_map[node_key] = kept_nodes
kept_nodes += 1
layer_configs = []
for layer in network.layers: # From the earliest layers on.
filtered_inbound_nodes = []
for original_node_index, node in enumerate(layer._inbound_nodes):
node_key = _make_node_key(layer.name, original_node_index)
if node_key in network._network_nodes and not node.is_input:
# The node is relevant to the model:
# add to filtered_inbound_nodes.
node_data = node.serialize(_make_node_key, node_conversion_map)
filtered_inbound_nodes.append(node_data)
with generic_utils.SharedObjectSavingScope():
for layer in network.layers: # From the earliest layers on.
filtered_inbound_nodes = []
for original_node_index, node in enumerate(layer._inbound_nodes):
node_key = _make_node_key(layer.name, original_node_index)
if node_key in network._network_nodes and not node.is_input:
# The node is relevant to the model:
# add to filtered_inbound_nodes.
node_data = node.serialize(_make_node_key, node_conversion_map)
filtered_inbound_nodes.append(node_data)
layer_config = serialize_layer_fn(layer)
layer_config['name'] = layer.name
layer_config['inbound_nodes'] = filtered_inbound_nodes
layer_configs.append(layer_config)
config['layers'] = layer_configs
layer_config = serialize_layer_fn(layer)
layer_config['name'] = layer.name
layer_config['inbound_nodes'] = filtered_inbound_nodes
layer_configs.append(layer_config)
config['layers'] = layer_configs
# Gather info about inputs and outputs.
model_inputs = []

View File

@ -393,10 +393,6 @@ def clone_model(model, input_tensors=None, clone_function=None):
except that it creates new layers (and thus new weights) instead
of sharing the weights of the existing layers.
`clone_model` will not preserve the uniqueness of shared objects within the
model (e.g. a single variable attached to two distinct layers will be
restored as two separate variables).
Args:
model: Instance of `Model`
(could be a functional model or a Sequential model).

View File

@ -148,9 +148,8 @@ def save_model(model,
hdf5_format.save_model_to_hdf5(
model, filepath, overwrite, include_optimizer)
else:
with generic_utils.SharedObjectSavingScope():
saved_model_save.save(model, filepath, overwrite, include_optimizer,
signatures, options, save_traces)
saved_model_save.save(model, filepath, overwrite, include_optimizer,
signatures, options, save_traces)
@keras_export('keras.models.load_model')
@ -195,18 +194,17 @@ def load_model(filepath, custom_objects=None, compile=True, options=None): # py
ImportError: if loading from an hdf5 file and h5py is not available.
IOError: In case of an invalid savefile.
"""
with generic_utils.SharedObjectLoadingScope():
with generic_utils.CustomObjectScope(custom_objects or {}):
with load_context.load_context(options):
if (h5py is not None and
(isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
compile)
with generic_utils.CustomObjectScope(custom_objects or {}):
with load_context.load_context(options):
if (h5py is not None and
(isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
compile)
filepath = path_to_string(filepath)
if isinstance(filepath, six.string_types):
loader_impl.parse_saved_model(filepath)
return saved_model_load.load(filepath, compile, options)
filepath = path_to_string(filepath)
if isinstance(filepath, six.string_types):
loader_impl.parse_saved_model(filepath)
return saved_model_load.load(filepath, compile, options)
raise IOError(
'Unable to load model. Filepath is not an hdf5 file (or h5py is not '

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import os
import shutil
import sys
@ -26,14 +25,12 @@ import tempfile
from absl.testing import parameterized
import numpy as np
from six import string_types
from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.feature_column import feature_column_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import combinations
@ -862,125 +859,6 @@ class TestWholeModelSaving(keras_parameterized.TestCase):
self.assertAllEqual(loaded_model.predict(args, batch_size=batch_size),
expected)
@combinations.generate(combinations.combine(mode=['eager']))
def test_shared_objects(self):
class OuterLayer(keras.layers.Layer):
def __init__(self, inner_layer):
super(OuterLayer, self).__init__()
self.inner_layer = inner_layer
def call(self, inputs):
return self.inner_layer(inputs)
def get_config(self):
return {
'inner_layer': generic_utils.serialize_keras_object(
self.inner_layer)
}
@classmethod
def from_config(cls, config):
return cls(generic_utils.deserialize_keras_object(
config['inner_layer']))
class InnerLayer(keras.layers.Layer):
def __init__(self):
super(InnerLayer, self).__init__()
self.v = self.add_weight(name='v', shape=[], dtype=dtypes.float32)
def call(self, inputs):
return self.v + inputs
@classmethod
def from_config(cls, config):
return cls()
# Create a model with 2 output layers that share the same inner layer.
inner_layer = InnerLayer()
outer_layer_1 = OuterLayer(inner_layer)
outer_layer_2 = OuterLayer(inner_layer)
input_ = keras.Input(shape=(1,))
model = keras.Model(
inputs=input_, outputs=[outer_layer_1(input_), outer_layer_2(input_)])
# Changes to the shared layer should affect both outputs.
model.layers[1].inner_layer.v.assign(5)
self.assertAllEqual(model(1), [6.0, 6.0])
model.layers[1].inner_layer.v.assign(3)
self.assertAllEqual(model(1), [4.0, 4.0])
# After loading, changes to the shared layer should still affect both
# outputs.
def _do_assertions(loaded):
loaded.layers[1].inner_layer.v.assign(5)
self.assertAllEqual(loaded(1), [6.0, 6.0])
loaded.layers[1].inner_layer.v.assign(3)
self.assertAllEqual(loaded(1), [4.0, 4.0])
loaded.layers[2].inner_layer.v.assign(5)
self.assertAllEqual(loaded(1), [6.0, 6.0])
loaded.layers[2].inner_layer.v.assign(3)
self.assertAllEqual(loaded(1), [4.0, 4.0])
# We'd like to make sure we only attach shared object IDs when strictly
# necessary, so we'll recursively traverse the generated config to count
# whether we have the exact number we expect.
def _get_all_keys_recursive(dict_or_iterable):
if isinstance(dict_or_iterable, dict):
for key in dict_or_iterable.keys():
yield key
for key in _get_all_keys_recursive(dict_or_iterable.values()):
yield key
elif isinstance(dict_or_iterable, string_types):
return
else:
try:
for item in dict_or_iterable:
for key in _get_all_keys_recursive(item):
yield key
# Not an iterable or dictionary
except TypeError:
return
with generic_utils.CustomObjectScope({
'OuterLayer': OuterLayer, 'InnerLayer': InnerLayer}):
# Test saving and loading to disk
save_format = testing_utils.get_save_format()
saved_model_dir = self._save_model_dir()
keras.models.save_model(model, saved_model_dir, save_format=save_format)
loaded = keras.models.load_model(saved_model_dir)
_do_assertions(loaded)
# Test recreating directly from config
config = model.get_config()
key_count = collections.Counter(_get_all_keys_recursive(config))
self.assertEqual(key_count[generic_utils.SHARED_OBJECT_KEY], 2)
loaded = keras.Model.from_config(config)
_do_assertions(loaded)
@combinations.generate(combinations.combine(mode=['eager']))
def test_shared_objects_wrapper(self):
"""Tests that shared layers wrapped with `Wrapper` restore correctly."""
input_ = keras.Input(shape=(1,))
unwrapped = keras.layers.Layer(name='unwrapped')
wrapped = keras.layers.Wrapper(unwrapped, name='wrapped')
model = keras.Model(inputs=input_,
outputs=[unwrapped(input_), wrapped(input_)])
# Test recreating directly from config
config = model.get_config()
loaded = keras.Model.from_config(config)
self.assertIs(loaded.layers[1], loaded.layers[2].layer)
# Test saving and loading to disk
save_format = testing_utils.get_save_format()
saved_model_dir = self._save_model_dir()
keras.models.save_model(model, saved_model_dir, save_format=save_format)
loaded = keras.models.load_model(saved_model_dir)
self.assertIs(loaded.layers[1], loaded.layers[2].layer)
# Factory functions to create models that will be serialized inside a Network.
def _make_graph_network(input_size, output_size):

View File

@ -46,6 +46,7 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
# TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
# the python config serialization has caught up.
metadata = dict(
class_name=generic_utils.get_registered_name(type(self.obj)),
name=self.obj.name,
trainable=self.obj.trainable,
expects_training_arg=self.obj._expects_training_arg, # pylint: disable=protected-access
@ -55,7 +56,7 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
must_restore_from_config=self.obj._must_restore_from_config, # pylint: disable=protected-access
)
metadata.update(get_serialized(self.obj))
metadata.update(get_config(self.obj))
if self.obj.input_spec is not None:
# Layer's input_spec has already been type-checked in the property setter.
metadata['input_spec'] = nest.map_structure(
@ -109,12 +110,16 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
# TODO(kathywu): Move serialization utils (and related utils from
# generic_utils.py) to a separate file.
def get_serialized(obj):
def get_config(obj):
with generic_utils.skip_failed_serialization():
# Store the config dictionary, which may be used when reviving the object.
# When loading, the program will attempt to revive the object from config,
# and if that fails, the object will be revived from the SavedModel.
return generic_utils.serialize_keras_object(obj)
config = generic_utils.serialize_keras_object(obj)['config']
if config is not None:
return {'config': config}
return {}
class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):

View File

@ -492,15 +492,13 @@ class KerasObjectLoader(object):
# found.
class_name = metadata.get('class_name')
config = metadata.get('config')
shared_object_id = metadata.get('shared_object_id')
must_restore_from_config = metadata.get('must_restore_from_config')
if not generic_utils.validate_config(config):
return None
try:
obj = layers_module.deserialize(
generic_utils.serialize_keras_class_and_config(
class_name, config, shared_object_id=shared_object_id))
generic_utils.serialize_keras_class_and_config(class_name, config))
except ValueError:
if must_restore_from_config:
raise RuntimeError(

View File

@ -36,7 +36,7 @@ class MetricSavedModelSaver(layer_serialization.LayerSavedModelSaver):
class_name=generic_utils.get_registered_name(type(self.obj)),
name=self.obj.name,
dtype=self.obj.dtype)
metadata.update(layer_serialization.get_serialized(self.obj))
metadata.update(layer_serialization.get_config(self.obj))
if self.obj._build_input_shape is not None: # pylint: disable=protected-access
metadata['build_input_shape'] = self.obj._build_input_shape # pylint: disable=protected-access
return metadata

View File

@ -24,10 +24,8 @@ import marshal
import os
import re
import sys
import threading
import time
import types as python_types
import weakref
import numpy as np
import six
@ -112,205 +110,9 @@ def get_custom_objects():
return _GLOBAL_CUSTOM_OBJECTS
# Store a unique, per-object ID for shared objects.
#
# We store a unique ID for each object so that we may, at loading time,
# re-create the network properly. Without this ID, we would have no way of
# determining whether a config is a description of a new object that
# should be created or is merely a reference to an already-created object.
SHARED_OBJECT_KEY = 'shared_object_id'
class NoopLoadingScope(object):
"""The default shared object loading scope. It does nothing.
Created to simplify serialization code that doesn't care about shared objects
(e.g. when serializing a single object).
"""
def get(self, unused_object_id):
return None
def set(self, object_id, obj):
pass
SHARED_OBJECT_LOADING = threading.local()
def _shared_object_loading_scope():
"""Get the current shared object saving scope in a threadsafe manner.
Attributes on the threadlocal variable must be set per-thread, thus we
cannot initialize these globally.
Returns:
A SharedObjectLoadingScope or NoopLoadingScope object.
"""
return getattr(SHARED_OBJECT_LOADING, 'scope', NoopLoadingScope())
class SharedObjectLoadingScope(object):
"""A context manager for keeping track of loaded objects.
During the deserialization process, we may come across objects that are
shared across multiple layers. In order to accurately restore the network
structure to its original state, `SharedObjectLoadingScope` allows us to
re-use shared objects rather than cloning them.
"""
def __enter__(self):
global SHARED_OBJECT_LOADING
SHARED_OBJECT_LOADING.scope = self
self._obj_ids_to_obj = {}
return self
def get(self, object_id):
"""Given a shared object ID, returns a previously instantiated object.
Args:
object_id: shared object ID to use when attempting to find already-loaded
object.
Returns:
The object, if we've seen this ID before. Else, `None`.
"""
# Explicitly check for `None` internally to make external calling code a
# bit cleaner.
if object_id is None:
return
return self._obj_ids_to_obj.get(object_id)
def set(self, object_id, obj):
"""Stores an instantiated object for future lookup and sharing."""
if object_id is None:
return
self._obj_ids_to_obj[object_id] = obj
def __exit__(self, *args, **kwargs):
global SHARED_OBJECT_LOADING
SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
SHARED_OBJECT_SAVING = threading.local()
def _shared_object_saving_scope():
"""Get the current shared object saving scope in a threadsafe manner.
Attributes on the threadlocal variable must be set per-thread, thus we
cannot initialize these globally.
Returns:
A SharedObjectSavingScope object or None.
"""
return getattr(SHARED_OBJECT_SAVING, 'scope', None)
class SharedObjectConfig(dict):
"""A configuration container that keeps track of references.
`SharedObjectConfig` will automatically attach a shared object ID to any
configs which are referenced more than once, allowing for proper shared
object reconstruction at load time.
In most cases, it would be more proper to subclass something like
`collections.UserDict` or `collections.Mapping` rather than `dict` directly.
Unfortunately, python's json encoder does not support `Mapping`s. This is
important functionality to retain, since we are dealing with serialization.
We should be safe to subclass `dict` here, since we aren't actually
overriding any core methods, only augmenting with a new one for reference
counting.
"""
def __init__(self, base_config, object_id, **kwargs):
self.ref_count = 1
self.object_id = object_id
super(SharedObjectConfig, self).__init__(base_config, **kwargs)
def increment_ref_count(self):
# As soon as we've seen the object more than once, we want to attach the
# shared object ID. This allows us to only attach the shared object ID when
# it's strictly necessary, making backwards compatibility breakage less
# likely.
if self.ref_count == 1:
self[SHARED_OBJECT_KEY] = self.object_id
self.ref_count += 1
class SharedObjectSavingScope(object):
"""Keeps track of shared object configs when serializing."""
def __enter__(self):
global SHARED_OBJECT_SAVING
# Serialization can happen at a number of layers for a number of reasons.
# We may end up with a case where we're opening a saving scope within
# another saving scope. In that case, we'd like to use the outermost scope
# available and ignore inner scopes, since there is not (yet) a reasonable
# use case for having these nested and distinct.
if _shared_object_saving_scope() is not None:
self._passthrough = True
return _shared_object_saving_scope()
else:
self._passthrough = False
SHARED_OBJECT_SAVING.scope = self
self._shared_objects_config = weakref.WeakKeyDictionary()
self._next_id = 0
return self
def get_config(self, obj):
"""Gets a `SharedObjectConfig` if one has already been seen for `obj`.
Args:
obj: The object for which to retrieve the `SharedObjectConfig`.
Returns:
The SharedObjectConfig for a given object, if already seen. Else,
`None`.
"""
if obj in self._shared_objects_config:
shared_object_config = self._shared_objects_config[obj]
shared_object_config.increment_ref_count()
return shared_object_config
def create_config(self, base_config, obj):
shared_object_config = SharedObjectConfig(base_config, self._next_id)
self._next_id += 1
self._shared_objects_config[obj] = shared_object_config
return shared_object_config
def __exit__(self, *args, **kwargs):
if not self._passthrough:
global SHARED_OBJECT_SAVING
SHARED_OBJECT_SAVING.scope = None
def serialize_keras_class_and_config(
cls_name, cls_config, obj=None, shared_object_id=None):
def serialize_keras_class_and_config(cls_name, cls_config):
"""Returns the serialization of the class with the given config."""
base_config = {'class_name': cls_name, 'config': cls_config}
# We call `serialize_keras_class_and_config` for some branches of the load
# path. In that case, we may already have a shared object ID we'd like to
# retain.
if shared_object_id is not None:
base_config[SHARED_OBJECT_KEY] = shared_object_id
# If we have an active `SharedObjectSavingScope`, check whether we've already
# serialized this config. If so, just use that config. This will store an
# extra ID field in the config, allowing us to re-create the shared object
# relationship at load time.
if _shared_object_saving_scope() is not None and obj is not None:
shared_object_config = _shared_object_saving_scope().get_config(obj)
if shared_object_config is None:
return _shared_object_saving_scope().create_config(base_config, obj)
return shared_object_config
return base_config
return {'class_name': cls_name, 'config': cls_config}
@keras_export('keras.utils.register_keras_serializable')
@ -432,19 +234,7 @@ def get_registered_object(name, custom_objects=None, module_objects=None):
@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
"""Serialize a Keras object into a JSON-compatible representation.
Calls to `serialize_keras_object` while underneath the
`SharedObjectSavingScope` context manager will cause any objects re-used
across multiple layers to be saved with a special shared object ID. This
allows the network to be re-created properly during deserialization.
Args:
instance: The object to serialize.
Returns:
A dict-like, JSON-compatible representation of the object's config.
"""
"""Serialize a Keras object into a JSON-compatible representation."""
_, instance = tf_decorator.unwrap(instance)
if instance is None:
return None
@ -475,8 +265,7 @@ def serialize_keras_object(instance):
serialization_config[key] = item
name = get_registered_name(instance.__class__)
return serialize_keras_class_and_config(
name, serialization_config, instance)
return serialize_keras_class_and_config(name, serialization_config)
if hasattr(instance, '__name__'):
return get_registered_name(instance)
raise ValueError('Cannot serialize', instance)
@ -497,9 +286,8 @@ def class_and_config_for_serialized_keras_object(
custom_objects=None,
printable_module_name='object'):
"""Returns the class name and config for a serialized keras object."""
if (not isinstance(config, dict)
or 'class_name' not in config
or 'config' not in config):
if (not isinstance(config, dict) or 'class_name' not in config or
'config' not in config):
raise ValueError('Improper config format: ' + str(config))
class_name = config['class_name']
@ -553,24 +341,7 @@ def deserialize_keras_object(identifier,
module_objects=None,
custom_objects=None,
printable_module_name='object'):
"""Turns the serialized form of a Keras object back into an actual object.
Calls to `deserialize_keras_object` while underneath the
`SharedObjectLoadingScope` context manager will cause any already-seen shared
objects to be returned as-is rather than creating a new object.
Args:
identifier: the serialized form of the object.
module_objects: A dictionary of custom objects to look the name up in.
Generally, module_objects is provided by midlevel library implementers.
custom_objects: A dictionary of custom objects to look the name up in.
Generally, custom_objects is provided by the user.
printable_module_name: A human-readable string representing the type of the
object. Printed in case of exception.
Returns:
The deserialized object.
"""
"""Turns the serialized form of a Keras object back into an actual object."""
if identifier is None:
return None
@ -580,39 +351,25 @@ def deserialize_keras_object(identifier,
(cls, cls_config) = class_and_config_for_serialized_keras_object(
config, module_objects, custom_objects, printable_module_name)
# If this object has already been loaded (i.e. it's shared between multiple
# objects), return the already-loaded object.
shared_object_id = config.get(SHARED_OBJECT_KEY)
shared_object = _shared_object_loading_scope().get(shared_object_id) # pylint: disable=assignment-from-none
if shared_object is not None:
return shared_object
if hasattr(cls, 'from_config'):
arg_spec = tf_inspect.getfullargspec(cls.from_config)
custom_objects = custom_objects or {}
if 'custom_objects' in arg_spec.args:
deserialized_obj = cls.from_config(
return cls.from_config(
cls_config,
custom_objects=dict(
list(_GLOBAL_CUSTOM_OBJECTS.items()) +
list(custom_objects.items())))
else:
with CustomObjectScope(custom_objects):
deserialized_obj = cls.from_config(cls_config)
with CustomObjectScope(custom_objects):
return cls.from_config(cls_config)
else:
# Then `cls` may be a function returning a class.
# in this case by convention `config` holds
# the kwargs of the function.
custom_objects = custom_objects or {}
with CustomObjectScope(custom_objects):
deserialized_obj = cls(**cls_config)
# Add object to shared objects, in case we find it referenced again.
_shared_object_loading_scope().set(shared_object_id, deserialized_obj)
return deserialized_obj
return cls(**cls_config)
elif isinstance(identifier, six.string_types):
object_name = identifier
if custom_objects and object_name in custom_objects:

View File

@ -23,7 +23,6 @@ from functools import partial
import numpy as np
from tensorflow.python import keras
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.platform import test
@ -385,63 +384,5 @@ class SliceArraysTest(test.TestCase):
[None, None, None])
# object() alone isn't compatible with WeakKeyDictionary, which we use to
# track shared configs.
class MaybeSharedObject(object):
pass
class SharedObjectScopeTest(test.TestCase):
def test_shared_object_saving_scope_single_object_doesnt_export_id(self):
with generic_utils.SharedObjectSavingScope() as scope:
single_object = MaybeSharedObject()
self.assertIsNone(scope.get_config(single_object))
single_object_config = scope.create_config({}, single_object)
self.assertIsNotNone(single_object_config)
self.assertNotIn(generic_utils.SHARED_OBJECT_KEY,
single_object_config)
def test_shared_object_saving_scope_shared_object_exports_id(self):
with generic_utils.SharedObjectSavingScope() as scope:
shared_object = MaybeSharedObject()
self.assertIsNone(scope.get_config(shared_object))
scope.create_config({}, shared_object)
first_object_config = scope.get_config(shared_object)
second_object_config = scope.get_config(shared_object)
self.assertIn(generic_utils.SHARED_OBJECT_KEY,
first_object_config)
self.assertIn(generic_utils.SHARED_OBJECT_KEY,
second_object_config)
self.assertIs(first_object_config, second_object_config)
def test_shared_object_loading_scope_noop(self):
# Test that, without a context manager scope, adding configs will do
# nothing.
obj_id = 1
obj = MaybeSharedObject()
generic_utils._shared_object_loading_scope().set(obj_id, obj)
self.assertIsNone(generic_utils._shared_object_loading_scope().get(obj_id))
def test_shared_object_loading_scope_returns_shared_obj(self):
obj_id = 1
obj = MaybeSharedObject()
with generic_utils.SharedObjectLoadingScope() as scope:
scope.set(obj_id, obj)
self.assertIs(scope.get(obj_id), obj)
def test_nested_shared_object_saving_scopes(self):
my_obj = MaybeSharedObject()
with generic_utils.SharedObjectSavingScope() as scope_1:
scope_1.create_config({}, my_obj)
with generic_utils.SharedObjectSavingScope() as scope_2:
# Nesting saving scopes should return the original scope and should
# not clear any objects we're tracking.
self.assertIs(scope_1, scope_2)
self.assertIsNotNone(scope_2.get_config(my_obj))
self.assertIsNotNone(scope_1.get_config(my_obj))
self.assertIsNone(generic_utils._shared_object_saving_scope())
if __name__ == '__main__':
test.main()