diff --git a/tensorflow/python/keras/engine/functional.py b/tensorflow/python/keras/engine/functional.py
index 743b4c05434..a3aa26540b1 100644
--- a/tensorflow/python/keras/engine/functional.py
+++ b/tensorflow/python/keras/engine/functional.py
@@ -671,12 +671,13 @@ class Functional(training_lib.Model):
     Raises:
         ValueError: In case of improperly formatted config dict.
     """
-    input_tensors, output_tensors, created_layers = reconstruct_from_config(
-        config, custom_objects)
-    model = cls(inputs=input_tensors, outputs=output_tensors,
-                name=config.get('name'))
-    connect_ancillary_layers(model, created_layers)
-    return model
+    with generic_utils.SharedObjectLoadingScope():
+      input_tensors, output_tensors, created_layers = reconstruct_from_config(
+          config, custom_objects)
+      model = cls(inputs=input_tensors, outputs=output_tensors,
+                  name=config.get('name'))
+      connect_ancillary_layers(model, created_layers)
+      return model
 
   def _validate_graph_inputs_and_outputs(self):
     """Validates the inputs and outputs of a Graph Network."""
@@ -1346,21 +1347,23 @@ def get_network_config(network, serialize_layer_fn=None):
       node_conversion_map[node_key] = kept_nodes
       kept_nodes += 1
   layer_configs = []
-  for layer in network.layers:  # From the earliest layers on.
-    filtered_inbound_nodes = []
-    for original_node_index, node in enumerate(layer._inbound_nodes):
-      node_key = _make_node_key(layer.name, original_node_index)
-      if node_key in network._network_nodes and not node.is_input:
-        # The node is relevant to the model:
-        # add to filtered_inbound_nodes.
-        node_data = node.serialize(_make_node_key, node_conversion_map)
-        filtered_inbound_nodes.append(node_data)
-    layer_config = serialize_layer_fn(layer)
-    layer_config['name'] = layer.name
-    layer_config['inbound_nodes'] = filtered_inbound_nodes
-    layer_configs.append(layer_config)
-  config['layers'] = layer_configs
+  with generic_utils.SharedObjectSavingScope():
+    for layer in network.layers:  # From the earliest layers on.
+      filtered_inbound_nodes = []
+      for original_node_index, node in enumerate(layer._inbound_nodes):
+        node_key = _make_node_key(layer.name, original_node_index)
+        if node_key in network._network_nodes and not node.is_input:
+          # The node is relevant to the model:
+          # add to filtered_inbound_nodes.
+          node_data = node.serialize(_make_node_key, node_conversion_map)
+          filtered_inbound_nodes.append(node_data)
+
+      layer_config = serialize_layer_fn(layer)
+      layer_config['name'] = layer.name
+      layer_config['inbound_nodes'] = filtered_inbound_nodes
+      layer_configs.append(layer_config)
+    config['layers'] = layer_configs
 
   # Gather info about inputs and outputs.
   model_inputs = []
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index b16e0d6fb60..0b19f4e2236 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -393,6 +393,10 @@ def clone_model(model, input_tensors=None, clone_function=None):
   except that it creates new layers (and thus new weights) instead
   of sharing the weights of the existing layers.
 
+  `clone_model` will not preserve the uniqueness of shared objects within the
+  model (e.g. a single variable attached to two distinct layers will be
+  restored as two separate variables).
+
   Args:
     model: Instance of `Model`
       (could be a functional model or a Sequential model).
@@ -420,15 +424,16 @@ def clone_model(model, input_tensors=None, clone_function=None):
   Raises:
     ValueError: in case of invalid `model` argument value.
""" - if clone_function is None: - clone_function = _clone_layer + with generic_utils.DisableSharedObjectScope(): + if clone_function is None: + clone_function = _clone_layer - if isinstance(model, Sequential): - return _clone_sequential_model( - model, input_tensors=input_tensors, layer_fn=clone_function) - else: - return _clone_functional_model( - model, input_tensors=input_tensors, layer_fn=clone_function) + if isinstance(model, Sequential): + return _clone_sequential_model( + model, input_tensors=input_tensors, layer_fn=clone_function) + else: + return _clone_functional_model( + model, input_tensors=input_tensors, layer_fn=clone_function) # "Clone" a subclassed model by reseting all of the attributes. diff --git a/tensorflow/python/keras/models_test.py b/tensorflow/python/keras/models_test.py index 0ece5ac69eb..12d1c39f100 100644 --- a/tensorflow/python/keras/models_test.py +++ b/tensorflow/python/keras/models_test.py @@ -245,6 +245,28 @@ class TestModelCloning(keras_parameterized.TestCase): loss = model.train_on_batch(x, y) self.assertEqual(float(loss), 0.) + def test_clone_rnn(self): + # Test cloning a model with multiple cells in an RNN. This exercises a + # few "fancier" features such as the `Bidrectional` wrapper and + # `StackedRNNCells` under the hood. + inputs = keras.Input(shape=(3, 3)) + cells = [ + keras.layers.LSTMCell( + units=32, + enable_caching_device=True, + implementation=2, + activation='relu')] + rnn = keras.layers.RNN(cells, return_sequences=True) + outputs = keras.layers.Bidirectional(rnn)(inputs) + outputs = keras.layers.Dense( + 12, activation='softmax', name='scores')(outputs) + model = keras.Model(inputs=inputs, outputs=outputs) + model.compile( + loss=keras.losses.CategoricalCrossentropy(), + optimizer=keras.optimizer_v2.rmsprop.RMSprop(lr=0.01), + metrics=['accuracy']) + keras.models.clone_model(model) + def test_model_cloning_invalid_use_cases(self): seq_model = keras.models.Sequential() seq_model.add(keras.layers.Dense(4, input_shape=(4,))) diff --git a/tensorflow/python/keras/saving/save.py b/tensorflow/python/keras/saving/save.py index d4749fcb4e8..ef7f6996071 100644 --- a/tensorflow/python/keras/saving/save.py +++ b/tensorflow/python/keras/saving/save.py @@ -148,8 +148,9 @@ def save_model(model, hdf5_format.save_model_to_hdf5( model, filepath, overwrite, include_optimizer) else: - saved_model_save.save(model, filepath, overwrite, include_optimizer, - signatures, options, save_traces) + with generic_utils.SharedObjectSavingScope(): + saved_model_save.save(model, filepath, overwrite, include_optimizer, + signatures, options, save_traces) @keras_export('keras.models.load_model') @@ -194,17 +195,18 @@ def load_model(filepath, custom_objects=None, compile=True, options=None): # py ImportError: if loading from an hdf5 file and h5py is not available. IOError: In case of an invalid savefile. 
""" - with generic_utils.CustomObjectScope(custom_objects or {}): - with load_context.load_context(options): - if (h5py is not None and - (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))): - return hdf5_format.load_model_from_hdf5(filepath, custom_objects, - compile) + with generic_utils.SharedObjectLoadingScope(): + with generic_utils.CustomObjectScope(custom_objects or {}): + with load_context.load_context(options): + if (h5py is not None and + (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))): + return hdf5_format.load_model_from_hdf5(filepath, custom_objects, + compile) - filepath = path_to_string(filepath) - if isinstance(filepath, six.string_types): - loader_impl.parse_saved_model(filepath) - return saved_model_load.load(filepath, compile, options) + filepath = path_to_string(filepath) + if isinstance(filepath, six.string_types): + loader_impl.parse_saved_model(filepath) + return saved_model_load.load(filepath, compile, options) raise IOError( 'Unable to load model. Filepath is not an hdf5 file (or h5py is not ' diff --git a/tensorflow/python/keras/saving/save_test.py b/tensorflow/python/keras/saving/save_test.py index 00c7bb2d84c..20a779b9b72 100644 --- a/tensorflow/python/keras/saving/save_test.py +++ b/tensorflow/python/keras/saving/save_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import os import shutil import sys @@ -25,12 +26,14 @@ import tempfile from absl.testing import parameterized import numpy as np +from six import string_types from tensorflow.python import keras from tensorflow.python import tf2 from tensorflow.python.eager import context from tensorflow.python.feature_column import feature_column_lib from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.keras import combinations @@ -859,6 +862,125 @@ class TestWholeModelSaving(keras_parameterized.TestCase): self.assertAllEqual(loaded_model.predict(args, batch_size=batch_size), expected) + @combinations.generate(combinations.combine(mode=['eager'])) + def test_shared_objects(self): + class OuterLayer(keras.layers.Layer): + + def __init__(self, inner_layer): + super(OuterLayer, self).__init__() + self.inner_layer = inner_layer + + def call(self, inputs): + return self.inner_layer(inputs) + + def get_config(self): + return { + 'inner_layer': generic_utils.serialize_keras_object( + self.inner_layer) + } + + @classmethod + def from_config(cls, config): + return cls(generic_utils.deserialize_keras_object( + config['inner_layer'])) + + class InnerLayer(keras.layers.Layer): + + def __init__(self): + super(InnerLayer, self).__init__() + self.v = self.add_weight(name='v', shape=[], dtype=dtypes.float32) + + def call(self, inputs): + return self.v + inputs + + @classmethod + def from_config(cls, config): + return cls() + + # Create a model with 2 output layers that share the same inner layer. + inner_layer = InnerLayer() + outer_layer_1 = OuterLayer(inner_layer) + outer_layer_2 = OuterLayer(inner_layer) + input_ = keras.Input(shape=(1,)) + model = keras.Model( + inputs=input_, outputs=[outer_layer_1(input_), outer_layer_2(input_)]) + + # Changes to the shared layer should affect both outputs. 
+    model.layers[1].inner_layer.v.assign(5)
+    self.assertAllEqual(model(1), [6.0, 6.0])
+    model.layers[1].inner_layer.v.assign(3)
+    self.assertAllEqual(model(1), [4.0, 4.0])
+
+    # After loading, changes to the shared layer should still affect both
+    # outputs.
+    def _do_assertions(loaded):
+      loaded.layers[1].inner_layer.v.assign(5)
+      self.assertAllEqual(loaded(1), [6.0, 6.0])
+      loaded.layers[1].inner_layer.v.assign(3)
+      self.assertAllEqual(loaded(1), [4.0, 4.0])
+      loaded.layers[2].inner_layer.v.assign(5)
+      self.assertAllEqual(loaded(1), [6.0, 6.0])
+      loaded.layers[2].inner_layer.v.assign(3)
+      self.assertAllEqual(loaded(1), [4.0, 4.0])
+
+    # We'd like to make sure we only attach shared object IDs when strictly
+    # necessary, so we'll recursively traverse the generated config to count
+    # whether we have the exact number we expect.
+    def _get_all_keys_recursive(dict_or_iterable):
+      if isinstance(dict_or_iterable, dict):
+        for key in dict_or_iterable.keys():
+          yield key
+        for key in _get_all_keys_recursive(dict_or_iterable.values()):
+          yield key
+      elif isinstance(dict_or_iterable, string_types):
+        return
+      else:
+        try:
+          for item in dict_or_iterable:
+            for key in _get_all_keys_recursive(item):
+              yield key
+        # Not an iterable or dictionary
+        except TypeError:
+          return
+
+    with generic_utils.CustomObjectScope({
+        'OuterLayer': OuterLayer, 'InnerLayer': InnerLayer}):
+
+      # Test saving and loading to disk
+      save_format = testing_utils.get_save_format()
+      saved_model_dir = self._save_model_dir()
+      keras.models.save_model(model, saved_model_dir, save_format=save_format)
+      loaded = keras.models.load_model(saved_model_dir)
+      _do_assertions(loaded)
+
+      # Test recreating directly from config
+      config = model.get_config()
+      key_count = collections.Counter(_get_all_keys_recursive(config))
+      self.assertEqual(key_count[generic_utils.SHARED_OBJECT_KEY], 2)
+      loaded = keras.Model.from_config(config)
+      _do_assertions(loaded)
+
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_shared_objects_wrapper(self):
+    """Tests that shared layers wrapped with `Wrapper` restore correctly."""
+    input_ = keras.Input(shape=(1,))
+    unwrapped = keras.layers.Layer(name='unwrapped')
+    wrapped = keras.layers.Wrapper(unwrapped, name='wrapped')
+    model = keras.Model(inputs=input_,
+                        outputs=[unwrapped(input_), wrapped(input_)])
+
+    # Test recreating directly from config
+    config = model.get_config()
+    loaded = keras.Model.from_config(config)
+    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
+
+    # Test saving and loading to disk
+    save_format = testing_utils.get_save_format()
+    saved_model_dir = self._save_model_dir()
+    keras.models.save_model(model, saved_model_dir, save_format=save_format)
+    loaded = keras.models.load_model(saved_model_dir)
+    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
+
 
 # Factory functions to create models that will be serialized inside a Network.
 def _make_graph_network(input_size, output_size):
diff --git a/tensorflow/python/keras/saving/saved_model/layer_serialization.py b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
index e2776bc70be..3f59a8ee726 100644
--- a/tensorflow/python/keras/saving/saved_model/layer_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/layer_serialization.py
@@ -46,7 +46,6 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
     # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
     # the python config serialization has caught up.
     metadata = dict(
-        class_name=generic_utils.get_registered_name(type(self.obj)),
         name=self.obj.name,
         trainable=self.obj.trainable,
         expects_training_arg=self.obj._expects_training_arg,  # pylint: disable=protected-access
@@ -56,7 +55,7 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
         must_restore_from_config=self.obj._must_restore_from_config,  # pylint: disable=protected-access
     )
 
-    metadata.update(get_config(self.obj))
+    metadata.update(get_serialized(self.obj))
     if self.obj.input_spec is not None:
       # Layer's input_spec has already been type-checked in the property setter.
       metadata['input_spec'] = nest.map_structure(
@@ -110,16 +109,12 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
 
 # TODO(kathywu): Move serialization utils (and related utils from
 # generic_utils.py) to a separate file.
-def get_config(obj):
+def get_serialized(obj):
   with generic_utils.skip_failed_serialization():
     # Store the config dictionary, which may be used when reviving the object.
     # When loading, the program will attempt to revive the object from config,
     # and if that fails, the object will be revived from the SavedModel.
-    config = generic_utils.serialize_keras_object(obj)['config']
-
-  if config is not None:
-    return {'config': config}
-  return {}
+    return generic_utils.serialize_keras_object(obj)
 
 
 class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py
index 217b124b97c..fc34bf39573 100644
--- a/tensorflow/python/keras/saving/saved_model/load.py
+++ b/tensorflow/python/keras/saving/saved_model/load.py
@@ -493,13 +493,15 @@ class KerasObjectLoader(object):
     # found.
     class_name = metadata.get('class_name')
     config = metadata.get('config')
+    shared_object_id = metadata.get('shared_object_id')
     must_restore_from_config = metadata.get('must_restore_from_config')
     if not generic_utils.validate_config(config):
       return None
 
     try:
       obj = layers_module.deserialize(
-          generic_utils.serialize_keras_class_and_config(class_name, config))
+          generic_utils.serialize_keras_class_and_config(
+              class_name, config, shared_object_id=shared_object_id))
     except ValueError:
       if must_restore_from_config:
         raise RuntimeError(
diff --git a/tensorflow/python/keras/saving/saved_model/metric_serialization.py b/tensorflow/python/keras/saving/saved_model/metric_serialization.py
index fda341d30b2..e2b6d3648cf 100644
--- a/tensorflow/python/keras/saving/saved_model/metric_serialization.py
+++ b/tensorflow/python/keras/saving/saved_model/metric_serialization.py
@@ -36,7 +36,7 @@ class MetricSavedModelSaver(layer_serialization.LayerSavedModelSaver):
         class_name=generic_utils.get_registered_name(type(self.obj)),
         name=self.obj.name,
         dtype=self.obj.dtype)
-    metadata.update(layer_serialization.get_config(self.obj))
+    metadata.update(layer_serialization.get_serialized(self.obj))
     if self.obj._build_input_shape is not None:  # pylint: disable=protected-access
       metadata['build_input_shape'] = self.obj._build_input_shape  # pylint: disable=protected-access
     return metadata
diff --git a/tensorflow/python/keras/utils/generic_utils.py b/tensorflow/python/keras/utils/generic_utils.py
index ecf382413ad..89aeaf4ab28 100644
--- a/tensorflow/python/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/utils/generic_utils.py
@@ -24,8 +24,10 @@ import marshal
 import os
 import re
 import sys
+import threading
 import time
 import types as python_types
+import weakref
 
 import numpy as np
 import six
@@ -110,9 +112,235 @@ def get_custom_objects():
   return _GLOBAL_CUSTOM_OBJECTS
 
 
-def serialize_keras_class_and_config(cls_name, cls_config):
+# Store a unique, per-object ID for shared objects.
+#
+# We store a unique ID for each object so that we may, at loading time,
+# re-create the network properly. Without this ID, we would have no way of
+# determining whether a config is a description of a new object that
+# should be created or is merely a reference to an already-created object.
+SHARED_OBJECT_KEY = 'shared_object_id'
+
+
+SHARED_OBJECT_DISABLED = threading.local()
+SHARED_OBJECT_LOADING = threading.local()
+SHARED_OBJECT_SAVING = threading.local()
+
+
+# Attributes on the threadlocal variable must be set per-thread, thus we
+# cannot initialize these globally. Instead, we have accessor functions with
+# default values.
+def _shared_object_disabled():
+  """Get whether shared object handling is disabled in a threadsafe manner."""
+  return getattr(SHARED_OBJECT_DISABLED, 'disabled', False)
+
+
+def _shared_object_loading_scope():
+  """Get the current shared object loading scope in a threadsafe manner."""
+  return getattr(SHARED_OBJECT_LOADING, 'scope', NoopLoadingScope())
+
+
+def _shared_object_saving_scope():
+  """Get the current shared object saving scope in a threadsafe manner."""
+  return getattr(SHARED_OBJECT_SAVING, 'scope', None)
+
+
+class DisableSharedObjectScope(object):
+  """A context manager for disabling handling of shared objects.
+
+  Disables shared object handling for both saving and loading.
+
+  Created primarily for use with `clone_model`, which does extra surgery that
+  is incompatible with shared objects.
+  """
+
+  def __enter__(self):
+    SHARED_OBJECT_DISABLED.disabled = True
+    self._orig_loading_scope = _shared_object_loading_scope()
+    self._orig_saving_scope = _shared_object_saving_scope()
+
+  def __exit__(self, *args, **kwargs):
+    SHARED_OBJECT_DISABLED.disabled = False
+    SHARED_OBJECT_LOADING.scope = self._orig_loading_scope
+    SHARED_OBJECT_SAVING.scope = self._orig_saving_scope
+
+
+class NoopLoadingScope(object):
+  """The default shared object loading scope. It does nothing.
+
+  Created to simplify serialization code that doesn't care about shared objects
+  (e.g. when serializing a single object).
+  """
+
+  def get(self, unused_object_id):
+    return None
+
+  def set(self, object_id, obj):
+    pass
+
+
+class SharedObjectLoadingScope(object):
+  """A context manager for keeping track of loaded objects.
+
+  During the deserialization process, we may come across objects that are
+  shared across multiple layers. In order to accurately restore the network
+  structure to its original state, `SharedObjectLoadingScope` allows us to
+  re-use shared objects rather than cloning them.
+  """
+
+  def __enter__(self):
+    if _shared_object_disabled():
+      return NoopLoadingScope()
+
+    global SHARED_OBJECT_LOADING
+    SHARED_OBJECT_LOADING.scope = self
+    self._obj_ids_to_obj = {}
+    return self
+
+  def get(self, object_id):
+    """Given a shared object ID, returns a previously instantiated object.
+
+    Args:
+      object_id: shared object ID to use when attempting to find already-loaded
+        object.
+
+    Returns:
+      The object, if we've seen this ID before. Else, `None`.
+    """
+    # Explicitly check for `None` internally to make external calling code a
+    # bit cleaner.
+    if object_id is None:
+      return
+    return self._obj_ids_to_obj.get(object_id)
+
+  def set(self, object_id, obj):
+    """Stores an instantiated object for future lookup and sharing."""
+    if object_id is None:
+      return
+    self._obj_ids_to_obj[object_id] = obj
+
+  def __exit__(self, *args, **kwargs):
+    global SHARED_OBJECT_LOADING
+    SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
+
+
+class SharedObjectConfig(dict):
+  """A configuration container that keeps track of references.
+
+  `SharedObjectConfig` will automatically attach a shared object ID to any
+  configs which are referenced more than once, allowing for proper shared
+  object reconstruction at load time.
+
+  In most cases, it would be more proper to subclass something like
+  `collections.UserDict` or `collections.Mapping` rather than `dict` directly.
+  Unfortunately, python's json encoder does not support `Mapping`s. This is
+  important functionality to retain, since we are dealing with serialization.
+
+  We should be safe to subclass `dict` here, since we aren't actually
+  overriding any core methods, only augmenting with a new one for reference
+  counting.
+  """
+
+  def __init__(self, base_config, object_id, **kwargs):
+    self.ref_count = 1
+    self.object_id = object_id
+    super(SharedObjectConfig, self).__init__(base_config, **kwargs)
+
+  def increment_ref_count(self):
+    # As soon as we've seen the object more than once, we want to attach the
+    # shared object ID. This allows us to only attach the shared object ID when
+    # it's strictly necessary, making backwards compatibility breakage less
+    # likely.
+    if self.ref_count == 1:
+      self[SHARED_OBJECT_KEY] = self.object_id
+    self.ref_count += 1
+
+
+class SharedObjectSavingScope(object):
+  """Keeps track of shared object configs when serializing."""
+
+  def __enter__(self):
+    if _shared_object_disabled():
+      return None
+
+    global SHARED_OBJECT_SAVING
+
+    # Serialization can happen at a number of layers for a number of reasons.
+    # We may end up with a case where we're opening a saving scope within
+    # another saving scope. In that case, we'd like to use the outermost scope
+    # available and ignore inner scopes, since there is not (yet) a reasonable
+    # use case for having these nested and distinct.
+    if _shared_object_saving_scope() is not None:
+      self._passthrough = True
+      return _shared_object_saving_scope()
+    else:
+      self._passthrough = False
+
+    SHARED_OBJECT_SAVING.scope = self
+    self._shared_objects_config = weakref.WeakKeyDictionary()
+    self._next_id = 0
+    return self
+
+  def get_config(self, obj):
+    """Gets a `SharedObjectConfig` if one has already been seen for `obj`.
+
+    Args:
+      obj: The object for which to retrieve the `SharedObjectConfig`.
+
+    Returns:
+      The SharedObjectConfig for a given object, if already seen. Else,
+      `None`.
+    """
+    try:
+      shared_object_config = self._shared_objects_config[obj]
+    except (TypeError, KeyError):
+      # If the object is unhashable (e.g. a subclass of `AbstractBaseClass`
+      # that has not overridden `__hash__`), a `TypeError` will be thrown.
+      # We'll just continue on without shared object support.
+      return None
+    shared_object_config.increment_ref_count()
+    return shared_object_config
+
+  def create_config(self, base_config, obj):
+    """Create a new SharedObjectConfig for a given object."""
+    shared_object_config = SharedObjectConfig(base_config, self._next_id)
+    self._next_id += 1
+    try:
+      self._shared_objects_config[obj] = shared_object_config
+    except TypeError:
+      # If the object is unhashable (e.g. a subclass of `AbstractBaseClass`
+      # that has not overridden `__hash__`), a `TypeError` will be thrown.
+      # We'll just continue on without shared object support.
+      pass
+    return shared_object_config
+
+  def __exit__(self, *args, **kwargs):
+    if not getattr(self, '_passthrough', False):
+      global SHARED_OBJECT_SAVING
+      SHARED_OBJECT_SAVING.scope = None
+
+
+def serialize_keras_class_and_config(
+    cls_name, cls_config, obj=None, shared_object_id=None):
   """Returns the serialization of the class with the given config."""
-  return {'class_name': cls_name, 'config': cls_config}
+  base_config = {'class_name': cls_name, 'config': cls_config}
+
+  # We call `serialize_keras_class_and_config` for some branches of the load
+  # path. In that case, we may already have a shared object ID we'd like to
+  # retain.
+  if shared_object_id is not None:
+    base_config[SHARED_OBJECT_KEY] = shared_object_id
+
+  # If we have an active `SharedObjectSavingScope`, check whether we've already
+  # serialized this config. If so, just use that config. This will store an
+  # extra ID field in the config, allowing us to re-create the shared object
+  # relationship at load time.
+  if _shared_object_saving_scope() is not None and obj is not None:
+    shared_object_config = _shared_object_saving_scope().get_config(obj)
+    if shared_object_config is None:
+      return _shared_object_saving_scope().create_config(base_config, obj)
+    return shared_object_config
+
+  return base_config
 
 
 @keras_export('keras.utils.register_keras_serializable')
@@ -234,7 +462,19 @@ def get_registered_object(name, custom_objects=None, module_objects=None):
 
 @keras_export('keras.utils.serialize_keras_object')
 def serialize_keras_object(instance):
-  """Serialize a Keras object into a JSON-compatible representation."""
+  """Serialize a Keras object into a JSON-compatible representation.
+
+  Calls to `serialize_keras_object` while underneath the
+  `SharedObjectSavingScope` context manager will cause any objects re-used
+  across multiple layers to be saved with a special shared object ID. This
+  allows the network to be re-created properly during deserialization.
+
+  Args:
+    instance: The object to serialize.
+
+  Returns:
+    A dict-like, JSON-compatible representation of the object's config.
+  """
   _, instance = tf_decorator.unwrap(instance)
   if instance is None:
     return None
@@ -265,7 +505,8 @@ def serialize_keras_object(instance):
         serialization_config[key] = item
 
     name = get_registered_name(instance.__class__)
-    return serialize_keras_class_and_config(name, serialization_config)
+    return serialize_keras_class_and_config(
+        name, serialization_config, instance)
   if hasattr(instance, '__name__'):
     return get_registered_name(instance)
   raise ValueError('Cannot serialize', instance)
@@ -286,8 +527,9 @@ def class_and_config_for_serialized_keras_object(
     custom_objects=None,
     printable_module_name='object'):
   """Returns the class name and config for a serialized keras object."""
-  if (not isinstance(config, dict) or 'class_name' not in config or
-      'config' not in config):
+  if (not isinstance(config, dict)
+      or 'class_name' not in config
+      or 'config' not in config):
     raise ValueError('Improper config format: ' + str(config))
 
   class_name = config['class_name']
@@ -341,7 +583,24 @@ def deserialize_keras_object(identifier,
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
-  """Turns the serialized form of a Keras object back into an actual object."""
+  """Turns the serialized form of a Keras object back into an actual object.
+
+  Calls to `deserialize_keras_object` while underneath the
+  `SharedObjectLoadingScope` context manager will cause any already-seen shared
+  objects to be returned as-is rather than creating a new object.
+
+  Args:
+    identifier: the serialized form of the object.
+    module_objects: A dictionary of custom objects to look the name up in.
+      Generally, module_objects is provided by midlevel library implementers.
+    custom_objects: A dictionary of custom objects to look the name up in.
+      Generally, custom_objects is provided by the user.
+    printable_module_name: A human-readable string representing the type of the
+      object. Printed in case of exception.
+
+  Returns:
+    The deserialized object.
+  """
   if identifier is None:
     return None
 
@@ -351,25 +610,39 @@ def deserialize_keras_object(identifier,
     (cls, cls_config) = class_and_config_for_serialized_keras_object(
         config, module_objects, custom_objects, printable_module_name)
 
+    # If this object has already been loaded (i.e. it's shared between multiple
+    # objects), return the already-loaded object.
+    shared_object_id = config.get(SHARED_OBJECT_KEY)
+    shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
+    if shared_object is not None:
+      return shared_object
+
     if hasattr(cls, 'from_config'):
       arg_spec = tf_inspect.getfullargspec(cls.from_config)
       custom_objects = custom_objects or {}
 
       if 'custom_objects' in arg_spec.args:
-        return cls.from_config(
+        deserialized_obj = cls.from_config(
             cls_config,
             custom_objects=dict(
                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                 list(custom_objects.items())))
-      with CustomObjectScope(custom_objects):
-        return cls.from_config(cls_config)
+      else:
+        with CustomObjectScope(custom_objects):
+          deserialized_obj = cls.from_config(cls_config)
     else:
       # Then `cls` may be a function returning a class.
       # in this case by convention `config` holds
       # the kwargs of the function.
      custom_objects = custom_objects or {}
       with CustomObjectScope(custom_objects):
-        return cls(**cls_config)
+        deserialized_obj = cls(**cls_config)
+
+    # Add object to shared objects, in case we find it referenced again.
+    _shared_object_loading_scope().set(shared_object_id, deserialized_obj)
+
+    return deserialized_obj
+
   elif isinstance(identifier, six.string_types):
     object_name = identifier
     if custom_objects and object_name in custom_objects:
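Reviewer sketch (not part of the patch): the intended save-side and load-side
flow of the helpers added above. The stand-in class, the 'SharedThing' class
name string, and the `{'units': 4}` config are made up for illustration; only
the generic_utils APIs from this diff are real.

from tensorflow.python.keras.utils import generic_utils


class SharedThing(object):  # hashable, so it works with WeakKeyDictionary
  pass


thing = SharedThing()
with generic_utils.SharedObjectSavingScope():
  # First reference: a fresh SharedObjectConfig with ref_count == 1; no
  # 'shared_object_id' key is attached yet.
  config = generic_utils.serialize_keras_class_and_config(
      'SharedThing', {'units': 4}, obj=thing)
  assert generic_utils.SHARED_OBJECT_KEY not in config

  # Second reference to the same object: the scope returns the same config
  # dict and, now that ref_count > 1, attaches the ID for the loader.
  config_2 = generic_utils.serialize_keras_class_and_config(
      'SharedThing', {'units': 4}, obj=thing)
  assert config_2 is config
  assert generic_utils.SHARED_OBJECT_KEY in config

# On the loading side, the scope maps IDs back to single instances.
with generic_utils.SharedObjectLoadingScope() as scope:
  scope.set(config[generic_utils.SHARED_OBJECT_KEY], thing)
  assert scope.get(config[generic_utils.SHARED_OBJECT_KEY]) is thing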
diff --git a/tensorflow/python/keras/utils/generic_utils_test.py b/tensorflow/python/keras/utils/generic_utils_test.py
index 2dc2952d328..dd28b17cb7d 100644
--- a/tensorflow/python/keras/utils/generic_utils_test.py
+++ b/tensorflow/python/keras/utils/generic_utils_test.py
@@ -23,6 +23,7 @@ from functools import partial
 import numpy as np
 
 from tensorflow.python import keras
+from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.platform import test
 
 
@@ -384,5 +385,63 @@ class SliceArraysTest(test.TestCase):
                         [None, None, None])
 
 
+# object() alone isn't compatible with WeakKeyDictionary, which we use to
+# track shared configs.
+class MaybeSharedObject(object):
+  pass
+
+
+class SharedObjectScopeTest(test.TestCase):
+
+  def test_shared_object_saving_scope_single_object_doesnt_export_id(self):
+    with generic_utils.SharedObjectSavingScope() as scope:
+      single_object = MaybeSharedObject()
+      self.assertIsNone(scope.get_config(single_object))
+      single_object_config = scope.create_config({}, single_object)
+      self.assertIsNotNone(single_object_config)
+      self.assertNotIn(generic_utils.SHARED_OBJECT_KEY,
+                       single_object_config)
+
+  def test_shared_object_saving_scope_shared_object_exports_id(self):
+    with generic_utils.SharedObjectSavingScope() as scope:
+      shared_object = MaybeSharedObject()
+      self.assertIsNone(scope.get_config(shared_object))
+      scope.create_config({}, shared_object)
+      first_object_config = scope.get_config(shared_object)
+      second_object_config = scope.get_config(shared_object)
+      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
+                    first_object_config)
+      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
+                    second_object_config)
+      self.assertIs(first_object_config, second_object_config)
+
+  def test_shared_object_loading_scope_noop(self):
+    # Test that, without a context manager scope, adding configs will do
+    # nothing.
+    obj_id = 1
+    obj = MaybeSharedObject()
+    generic_utils._shared_object_loading_scope().set(obj_id, obj)
+    self.assertIsNone(generic_utils._shared_object_loading_scope().get(obj_id))
+
+  def test_shared_object_loading_scope_returns_shared_obj(self):
+    obj_id = 1
+    obj = MaybeSharedObject()
+    with generic_utils.SharedObjectLoadingScope() as scope:
+      scope.set(obj_id, obj)
+      self.assertIs(scope.get(obj_id), obj)
+
+  def test_nested_shared_object_saving_scopes(self):
+    my_obj = MaybeSharedObject()
+    with generic_utils.SharedObjectSavingScope() as scope_1:
+      scope_1.create_config({}, my_obj)
+      with generic_utils.SharedObjectSavingScope() as scope_2:
+        # Nesting saving scopes should return the original scope and should
+        # not clear any objects we're tracking.
+        self.assertIs(scope_1, scope_2)
+        self.assertIsNotNone(scope_2.get_config(my_obj))
+      self.assertIsNotNone(scope_1.get_config(my_obj))
+    self.assertIsNone(generic_utils._shared_object_saving_scope())
+
+
 if __name__ == '__main__':
   test.main()
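Reviewer sketch (not part of the patch): the end-to-end behavior the tests
above verify, shown with a stock layer argument instead of custom layers. It
reuses the shared-regularizer model from the earlier clone_model sketch, but
here sharing survives: `get_config()` now runs under SharedObjectSavingScope
and `Model.from_config()` under SharedObjectLoadingScope. Layer names are
illustrative.

from tensorflow.python import keras

inputs = keras.Input(shape=(4,))
reg = keras.regularizers.L2(0.01)  # one instance, shared by both Dense layers
x = keras.layers.Dense(8, kernel_regularizer=reg, name='d1')(inputs)
outputs = keras.layers.Dense(1, kernel_regularizer=reg, name='d2')(x)
model = keras.Model(inputs, outputs)

# The second reference to `reg` carries a 'shared_object_id' in the config...
config = model.get_config()

# ...so the loader resolves both references to a single regularizer instance.
loaded = keras.Model.from_config(config)
assert (loaded.get_layer('d1').kernel_regularizer
        is loaded.get_layer('d2').kernel_regularizer)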