Add support for shared objects in keras when (de)serializing.

Previously, objects shared between multiple layers would be duplicated when saving and loading. This commit attaches a unique ID to Keras objects during serialization, which lets us create only a single instance of each shared object when deserializing.

PiperOrigin-RevId: 353123479
Change-Id: I057ba61fe587d5cb97238fb31ca562dc5791cf88
parent a7d6c05bd6
commit 2d3d381d5f
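For context, a minimal sketch of the behavior this change enables, mirroring the `test_shared_objects_wrapper` test added below (the layer names are illustrative, and `tf.keras.layers.Wrapper` stands in for any layer that serializes another layer inside its own config):

import tensorflow as tf

# One layer instance is referenced twice: once directly as a model output and
# once nested inside a Wrapper's config. Without shared object IDs, recreating
# the model from its config produced two independent copies of `inner`.
inputs = tf.keras.Input(shape=(1,))
inner = tf.keras.layers.Layer(name='inner')
wrapped = tf.keras.layers.Wrapper(inner, name='wrapped')
model = tf.keras.Model(inputs=inputs,
                       outputs=[inner(inputs), wrapped(inputs)])

# With this change, the second reference to `inner` in the config carries a
# 'shared_object_id', so deserialization restores a single shared instance.
restored = tf.keras.Model.from_config(model.get_config())
assert restored.layers[1] is restored.layers[2].layer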
tensorflow/python/keras/engine/functional.py
@@ -671,12 +671,13 @@ class Functional(training_lib.Model):
     Raises:
       ValueError: In case of improperly formatted config dict.
     """
-    input_tensors, output_tensors, created_layers = reconstruct_from_config(
-        config, custom_objects)
-    model = cls(inputs=input_tensors, outputs=output_tensors,
-                name=config.get('name'))
-    connect_ancillary_layers(model, created_layers)
-    return model
+    with generic_utils.SharedObjectLoadingScope():
+      input_tensors, output_tensors, created_layers = reconstruct_from_config(
+          config, custom_objects)
+      model = cls(inputs=input_tensors, outputs=output_tensors,
+                  name=config.get('name'))
+      connect_ancillary_layers(model, created_layers)
+      return model
 
   def _validate_graph_inputs_and_outputs(self):
     """Validates the inputs and outputs of a Graph Network."""
@@ -1346,21 +1347,23 @@ def get_network_config(network, serialize_layer_fn=None):
         node_conversion_map[node_key] = kept_nodes
         kept_nodes += 1
   layer_configs = []
-  for layer in network.layers:  # From the earliest layers on.
-    filtered_inbound_nodes = []
-    for original_node_index, node in enumerate(layer._inbound_nodes):
-      node_key = _make_node_key(layer.name, original_node_index)
-      if node_key in network._network_nodes and not node.is_input:
-        # The node is relevant to the model:
-        # add to filtered_inbound_nodes.
-        node_data = node.serialize(_make_node_key, node_conversion_map)
-        filtered_inbound_nodes.append(node_data)
-
-    layer_config = serialize_layer_fn(layer)
-    layer_config['name'] = layer.name
-    layer_config['inbound_nodes'] = filtered_inbound_nodes
-    layer_configs.append(layer_config)
+
+  with generic_utils.SharedObjectSavingScope():
+    for layer in network.layers:  # From the earliest layers on.
+      filtered_inbound_nodes = []
+      for original_node_index, node in enumerate(layer._inbound_nodes):
+        node_key = _make_node_key(layer.name, original_node_index)
+        if node_key in network._network_nodes and not node.is_input:
+          # The node is relevant to the model:
+          # add to filtered_inbound_nodes.
+          node_data = node.serialize(_make_node_key, node_conversion_map)
+          filtered_inbound_nodes.append(node_data)
+
+      layer_config = serialize_layer_fn(layer)
+      layer_config['name'] = layer.name
+      layer_config['inbound_nodes'] = filtered_inbound_nodes
+      layer_configs.append(layer_config)
   config['layers'] = layer_configs
 
   # Gather info about inputs and outputs.
   model_inputs = []
tensorflow/python/keras/models.py
@@ -393,6 +393,10 @@ def clone_model(model, input_tensors=None, clone_function=None):
   except that it creates new layers (and thus new weights) instead
   of sharing the weights of the existing layers.
 
+  `clone_model` will not preserve the uniqueness of shared objects within the
+  model (e.g. a single variable attached to two distinct layers will be
+  restored as two separate variables).
+
   Args:
     model: Instance of `Model`
       (could be a functional model or a Sequential model).
tensorflow/python/keras/saving/save.py
@@ -148,8 +148,9 @@ def save_model(model,
     hdf5_format.save_model_to_hdf5(
         model, filepath, overwrite, include_optimizer)
   else:
-    saved_model_save.save(model, filepath, overwrite, include_optimizer,
-                          signatures, options, save_traces)
+    with generic_utils.SharedObjectSavingScope():
+      saved_model_save.save(model, filepath, overwrite, include_optimizer,
+                            signatures, options, save_traces)
 
 
 @keras_export('keras.models.load_model')
@@ -194,17 +195,18 @@ def load_model(filepath, custom_objects=None, compile=True, options=None):  # pylint: disable=redefined-builtin
     ImportError: if loading from an hdf5 file and h5py is not available.
     IOError: In case of an invalid savefile.
   """
-  with generic_utils.CustomObjectScope(custom_objects or {}):
-    with load_context.load_context(options):
-      if (h5py is not None and
-          (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
-        return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
-                                                compile)
+  with generic_utils.SharedObjectLoadingScope():
+    with generic_utils.CustomObjectScope(custom_objects or {}):
+      with load_context.load_context(options):
+        if (h5py is not None and
+            (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
+          return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
+                                                  compile)
 
-      filepath = path_to_string(filepath)
-      if isinstance(filepath, six.string_types):
-        loader_impl.parse_saved_model(filepath)
-        return saved_model_load.load(filepath, compile, options)
+        filepath = path_to_string(filepath)
+        if isinstance(filepath, six.string_types):
+          loader_impl.parse_saved_model(filepath)
+          return saved_model_load.load(filepath, compile, options)
 
   raise IOError(
       'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
tensorflow/python/keras/saving/save_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import os
 import shutil
 import sys
@@ -25,12 +26,14 @@ import tempfile
 
 from absl.testing import parameterized
 import numpy as np
+from six import string_types
 
 from tensorflow.python import keras
 from tensorflow.python import tf2
 from tensorflow.python.eager import context
 from tensorflow.python.feature_column import feature_column_lib
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.keras import combinations
@@ -859,6 +862,125 @@ class TestWholeModelSaving(keras_parameterized.TestCase):
     self.assertAllEqual(loaded_model.predict(args, batch_size=batch_size),
                         expected)
 
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_shared_objects(self):
+    class OuterLayer(keras.layers.Layer):
+
+      def __init__(self, inner_layer):
+        super(OuterLayer, self).__init__()
+        self.inner_layer = inner_layer
+
+      def call(self, inputs):
+        return self.inner_layer(inputs)
+
+      def get_config(self):
+        return {
+            'inner_layer': generic_utils.serialize_keras_object(
+                self.inner_layer)
+        }
+
+      @classmethod
+      def from_config(cls, config):
+        return cls(generic_utils.deserialize_keras_object(
+            config['inner_layer']))
+
+    class InnerLayer(keras.layers.Layer):
+
+      def __init__(self):
+        super(InnerLayer, self).__init__()
+        self.v = self.add_weight(name='v', shape=[], dtype=dtypes.float32)
+
+      def call(self, inputs):
+        return self.v + inputs
+
+      @classmethod
+      def from_config(cls, config):
+        return cls()
+
+    # Create a model with 2 output layers that share the same inner layer.
+    inner_layer = InnerLayer()
+    outer_layer_1 = OuterLayer(inner_layer)
+    outer_layer_2 = OuterLayer(inner_layer)
+    input_ = keras.Input(shape=(1,))
+    model = keras.Model(
+        inputs=input_, outputs=[outer_layer_1(input_), outer_layer_2(input_)])
+
+    # Changes to the shared layer should affect both outputs.
+    model.layers[1].inner_layer.v.assign(5)
+    self.assertAllEqual(model(1), [6.0, 6.0])
+    model.layers[1].inner_layer.v.assign(3)
+    self.assertAllEqual(model(1), [4.0, 4.0])
+
+    # After loading, changes to the shared layer should still affect both
+    # outputs.
+    def _do_assertions(loaded):
+      loaded.layers[1].inner_layer.v.assign(5)
+      self.assertAllEqual(loaded(1), [6.0, 6.0])
+      loaded.layers[1].inner_layer.v.assign(3)
+      self.assertAllEqual(loaded(1), [4.0, 4.0])
+      loaded.layers[2].inner_layer.v.assign(5)
+      self.assertAllEqual(loaded(1), [6.0, 6.0])
+      loaded.layers[2].inner_layer.v.assign(3)
+      self.assertAllEqual(loaded(1), [4.0, 4.0])
+
+    # We'd like to make sure we only attach shared object IDs when strictly
+    # necessary, so we'll recursively traverse the generated config to count
+    # whether we have the exact number we expect.
+    def _get_all_keys_recursive(dict_or_iterable):
+      if isinstance(dict_or_iterable, dict):
+        for key in dict_or_iterable.keys():
+          yield key
+        for key in _get_all_keys_recursive(dict_or_iterable.values()):
+          yield key
+      elif isinstance(dict_or_iterable, string_types):
+        return
+      else:
+        try:
+          for item in dict_or_iterable:
+            for key in _get_all_keys_recursive(item):
+              yield key
+        # Not an iterable or dictionary
+        except TypeError:
+          return
+
+    with generic_utils.CustomObjectScope({
+        'OuterLayer': OuterLayer, 'InnerLayer': InnerLayer}):
+
+      # Test saving and loading to disk
+      save_format = testing_utils.get_save_format()
+      saved_model_dir = self._save_model_dir()
+      keras.models.save_model(model, saved_model_dir, save_format=save_format)
+      loaded = keras.models.load_model(saved_model_dir)
+      _do_assertions(loaded)
+
+      # Test recreating directly from config
+      config = model.get_config()
+      key_count = collections.Counter(_get_all_keys_recursive(config))
+      self.assertEqual(key_count[generic_utils.SHARED_OBJECT_KEY], 2)
+      loaded = keras.Model.from_config(config)
+      _do_assertions(loaded)
+
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_shared_objects_wrapper(self):
+    """Tests that shared layers wrapped with `Wrapper` restore correctly."""
+    input_ = keras.Input(shape=(1,))
+    unwrapped = keras.layers.Layer(name='unwrapped')
+    wrapped = keras.layers.Wrapper(unwrapped, name='wrapped')
+    model = keras.Model(inputs=input_,
+                        outputs=[unwrapped(input_), wrapped(input_)])
+
+    # Test recreating directly from config
+    config = model.get_config()
+    loaded = keras.Model.from_config(config)
+    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
+
+    # Test saving and loading to disk
+    save_format = testing_utils.get_save_format()
+    saved_model_dir = self._save_model_dir()
+    keras.models.save_model(model, saved_model_dir, save_format=save_format)
+    loaded = keras.models.load_model(saved_model_dir)
+    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
+
 
 # Factory functions to create models that will be serialized inside a Network.
 def _make_graph_network(input_size, output_size):
tensorflow/python/keras/saving/saved_model/layer_serialization.py
@@ -46,7 +46,6 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
     # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
     # the python config serialization has caught up.
     metadata = dict(
-        class_name=generic_utils.get_registered_name(type(self.obj)),
         name=self.obj.name,
         trainable=self.obj.trainable,
         expects_training_arg=self.obj._expects_training_arg,  # pylint: disable=protected-access
@@ -56,7 +55,7 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
         must_restore_from_config=self.obj._must_restore_from_config,  # pylint: disable=protected-access
     )
 
-    metadata.update(get_config(self.obj))
+    metadata.update(get_serialized(self.obj))
     if self.obj.input_spec is not None:
       # Layer's input_spec has already been type-checked in the property setter.
       metadata['input_spec'] = nest.map_structure(
@@ -110,16 +109,12 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
 
 # TODO(kathywu): Move serialization utils (and related utils from
 # generic_utils.py) to a separate file.
-def get_config(obj):
+def get_serialized(obj):
   with generic_utils.skip_failed_serialization():
     # Store the config dictionary, which may be used when reviving the object.
     # When loading, the program will attempt to revive the object from config,
    # and if that fails, the object will be revived from the SavedModel.
-    config = generic_utils.serialize_keras_object(obj)['config']
-
-  if config is not None:
-    return {'config': config}
-  return {}
+    return generic_utils.serialize_keras_object(obj)
 
 
 class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
tensorflow/python/keras/saving/saved_model/load.py
@@ -492,13 +492,15 @@ class KerasObjectLoader(object):
     # found.
     class_name = metadata.get('class_name')
     config = metadata.get('config')
+    shared_object_id = metadata.get('shared_object_id')
     must_restore_from_config = metadata.get('must_restore_from_config')
     if not generic_utils.validate_config(config):
       return None
 
     try:
       obj = layers_module.deserialize(
-          generic_utils.serialize_keras_class_and_config(class_name, config))
+          generic_utils.serialize_keras_class_and_config(
+              class_name, config, shared_object_id=shared_object_id))
     except ValueError:
       if must_restore_from_config:
         raise RuntimeError(
tensorflow/python/keras/saving/saved_model/metric_serialization.py
@@ -36,7 +36,7 @@ class MetricSavedModelSaver(layer_serialization.LayerSavedModelSaver):
         class_name=generic_utils.get_registered_name(type(self.obj)),
         name=self.obj.name,
         dtype=self.obj.dtype)
-    metadata.update(layer_serialization.get_config(self.obj))
+    metadata.update(layer_serialization.get_serialized(self.obj))
     if self.obj._build_input_shape is not None:  # pylint: disable=protected-access
       metadata['build_input_shape'] = self.obj._build_input_shape  # pylint: disable=protected-access
     return metadata
tensorflow/python/keras/utils/generic_utils.py
@@ -24,8 +24,10 @@ import marshal
 import os
 import re
 import sys
+import threading
 import time
 import types as python_types
+import weakref
 
 import numpy as np
 import six
@@ -110,9 +112,205 @@ def get_custom_objects():
   return _GLOBAL_CUSTOM_OBJECTS
 
 
-def serialize_keras_class_and_config(cls_name, cls_config):
-  """Returns the serialization of the class with the given config."""
-  return {'class_name': cls_name, 'config': cls_config}
+# Store a unique, per-object ID for shared objects.
+#
+# We store a unique ID for each object so that we may, at loading time,
+# re-create the network properly. Without this ID, we would have no way of
+# determining whether a config is a description of a new object that
+# should be created or is merely a reference to an already-created object.
+SHARED_OBJECT_KEY = 'shared_object_id'
+
+
+class NoopLoadingScope(object):
+  """The default shared object loading scope. It does nothing.
+
+  Created to simplify serialization code that doesn't care about shared objects
+  (e.g. when serializing a single object).
+  """
+
+  def get(self, unused_object_id):
+    return None
+
+  def set(self, object_id, obj):
+    pass
+
+
+SHARED_OBJECT_LOADING = threading.local()
+
+
+def _shared_object_loading_scope():
+  """Get the current shared object loading scope in a threadsafe manner.
+
+  Attributes on the threadlocal variable must be set per-thread, thus we
+  cannot initialize these globally.
+
+  Returns:
+    A SharedObjectLoadingScope or NoopLoadingScope object.
+  """
+  return getattr(SHARED_OBJECT_LOADING, 'scope', NoopLoadingScope())
+
+
+class SharedObjectLoadingScope(object):
+  """A context manager for keeping track of loaded objects.
+
+  During the deserialization process, we may come across objects that are
+  shared across multiple layers. In order to accurately restore the network
+  structure to its original state, `SharedObjectLoadingScope` allows us to
+  re-use shared objects rather than cloning them.
+  """
+
+  def __enter__(self):
+    global SHARED_OBJECT_LOADING
+
+    SHARED_OBJECT_LOADING.scope = self
+    self._obj_ids_to_obj = {}
+    return self
+
+  def get(self, object_id):
+    """Given a shared object ID, returns a previously instantiated object.
+
+    Args:
+      object_id: shared object ID to use when attempting to find already-loaded
+        object.
+
+    Returns:
+      The object, if we've seen this ID before. Else, `None`.
+    """
+    # Explicitly check for `None` internally to make external calling code a
+    # bit cleaner.
+    if object_id is None:
+      return
+    return self._obj_ids_to_obj.get(object_id)
+
+  def set(self, object_id, obj):
+    """Stores an instantiated object for future lookup and sharing."""
+    if object_id is None:
+      return
+    self._obj_ids_to_obj[object_id] = obj
+
+  def __exit__(self, *args, **kwargs):
+    global SHARED_OBJECT_LOADING
+    SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
+
+
+SHARED_OBJECT_SAVING = threading.local()
+
+
+def _shared_object_saving_scope():
+  """Get the current shared object saving scope in a threadsafe manner.
+
+  Attributes on the threadlocal variable must be set per-thread, thus we
+  cannot initialize these globally.
+
+  Returns:
+    A SharedObjectSavingScope object or None.
+  """
+  return getattr(SHARED_OBJECT_SAVING, 'scope', None)
+
+
+class SharedObjectConfig(dict):
+  """A configuration container that keeps track of references.
+
+  `SharedObjectConfig` will automatically attach a shared object ID to any
+  configs which are referenced more than once, allowing for proper shared
+  object reconstruction at load time.
+
+  In most cases, it would be more proper to subclass something like
+  `collections.UserDict` or `collections.Mapping` rather than `dict` directly.
+  Unfortunately, python's json encoder does not support `Mapping`s. This is
+  important functionality to retain, since we are dealing with serialization.
+
+  We should be safe to subclass `dict` here, since we aren't actually
+  overriding any core methods, only augmenting with a new one for reference
+  counting.
+  """
+
+  def __init__(self, base_config, object_id, **kwargs):
+    self.ref_count = 1
+    self.object_id = object_id
+    super(SharedObjectConfig, self).__init__(base_config, **kwargs)
+
+  def increment_ref_count(self):
+    # As soon as we've seen the object more than once, we want to attach the
+    # shared object ID. This allows us to only attach the shared object ID when
+    # it's strictly necessary, making backwards compatibility breakage less
+    # likely.
+    if self.ref_count == 1:
+      self[SHARED_OBJECT_KEY] = self.object_id
+    self.ref_count += 1
+
+
+class SharedObjectSavingScope(object):
+  """Keeps track of shared object configs when serializing."""
+
+  def __enter__(self):
+    global SHARED_OBJECT_SAVING
+
+    # Serialization can happen at a number of layers for a number of reasons.
+    # We may end up with a case where we're opening a saving scope within
+    # another saving scope. In that case, we'd like to use the outermost scope
+    # available and ignore inner scopes, since there is not (yet) a reasonable
+    # use case for having these nested and distinct.
+    if _shared_object_saving_scope() is not None:
+      self._passthrough = True
+      return _shared_object_saving_scope()
+    else:
+      self._passthrough = False
+
+    SHARED_OBJECT_SAVING.scope = self
+    self._shared_objects_config = weakref.WeakKeyDictionary()
+    self._next_id = 0
+    return self
+
+  def get_config(self, obj):
+    """Gets a `SharedObjectConfig` if one has already been seen for `obj`.
+
+    Args:
+      obj: The object for which to retrieve the `SharedObjectConfig`.
+
+    Returns:
+      The SharedObjectConfig for a given object, if already seen. Else,
+        `None`.
+    """
+    if obj in self._shared_objects_config:
+      shared_object_config = self._shared_objects_config[obj]
+      shared_object_config.increment_ref_count()
+      return shared_object_config
+
+  def create_config(self, base_config, obj):
+    shared_object_config = SharedObjectConfig(base_config, self._next_id)
+    self._next_id += 1
+    self._shared_objects_config[obj] = shared_object_config
+    return shared_object_config
+
+  def __exit__(self, *args, **kwargs):
+    if not self._passthrough:
+      global SHARED_OBJECT_SAVING
+      SHARED_OBJECT_SAVING.scope = None
+
+
+def serialize_keras_class_and_config(
+    cls_name, cls_config, obj=None, shared_object_id=None):
+  """Returns the serialization of the class with the given config."""
+  base_config = {'class_name': cls_name, 'config': cls_config}
+
+  # We call `serialize_keras_class_and_config` for some branches of the load
+  # path. In that case, we may already have a shared object ID we'd like to
+  # retain.
+  if shared_object_id is not None:
+    base_config[SHARED_OBJECT_KEY] = shared_object_id
+
+  # If we have an active `SharedObjectSavingScope`, check whether we've already
+  # serialized this config. If so, just use that config. This will store an
+  # extra ID field in the config, allowing us to re-create the shared object
+  # relationship at load time.
+  if _shared_object_saving_scope() is not None and obj is not None:
+    shared_object_config = _shared_object_saving_scope().get_config(obj)
+    if shared_object_config is None:
+      return _shared_object_saving_scope().create_config(base_config, obj)
+    return shared_object_config
+
+  return base_config
 
 
 @keras_export('keras.utils.register_keras_serializable')
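As a usage sketch of the scopes above (not part of the diff; the internal import path matches the module being changed here):

import tensorflow as tf
from tensorflow.python.keras.utils import generic_utils

dense = tf.keras.layers.Dense(1)
with generic_utils.SharedObjectSavingScope():
  first = generic_utils.serialize_keras_object(dense)
  second = generic_utils.serialize_keras_object(dense)

# Both calls return the same `SharedObjectConfig`; the shared object ID is
# attached only once the object has been seen more than once.
assert first is second
assert generic_utils.SHARED_OBJECT_KEY in first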
@@ -234,7 +432,19 @@ def get_registered_object(name, custom_objects=None, module_objects=None):
 
 @keras_export('keras.utils.serialize_keras_object')
 def serialize_keras_object(instance):
-  """Serialize a Keras object into a JSON-compatible representation."""
+  """Serialize a Keras object into a JSON-compatible representation.
+
+  Calls to `serialize_keras_object` while underneath the
+  `SharedObjectSavingScope` context manager will cause any objects re-used
+  across multiple layers to be saved with a special shared object ID. This
+  allows the network to be re-created properly during deserialization.
+
+  Args:
+    instance: The object to serialize.
+
+  Returns:
+    A dict-like, JSON-compatible representation of the object's config.
+  """
   _, instance = tf_decorator.unwrap(instance)
   if instance is None:
     return None
@@ -265,7 +475,8 @@ def serialize_keras_object(instance):
       serialization_config[key] = item
 
     name = get_registered_name(instance.__class__)
-    return serialize_keras_class_and_config(name, serialization_config)
+    return serialize_keras_class_and_config(
+        name, serialization_config, instance)
   if hasattr(instance, '__name__'):
     return get_registered_name(instance)
   raise ValueError('Cannot serialize', instance)
@@ -286,8 +497,9 @@ def class_and_config_for_serialized_keras_object(
     custom_objects=None,
     printable_module_name='object'):
   """Returns the class name and config for a serialized keras object."""
-  if (not isinstance(config, dict) or 'class_name' not in config or
-      'config' not in config):
+  if (not isinstance(config, dict)
+      or 'class_name' not in config
+      or 'config' not in config):
     raise ValueError('Improper config format: ' + str(config))
 
   class_name = config['class_name']
@@ -341,7 +553,24 @@ def deserialize_keras_object(identifier,
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
-  """Turns the serialized form of a Keras object back into an actual object."""
+  """Turns the serialized form of a Keras object back into an actual object.
+
+  Calls to `deserialize_keras_object` while underneath the
+  `SharedObjectLoadingScope` context manager will cause any already-seen shared
+  objects to be returned as-is rather than creating a new object.
+
+  Args:
+    identifier: the serialized form of the object.
+    module_objects: A dictionary of custom objects to look the name up in.
+      Generally, module_objects is provided by midlevel library implementers.
+    custom_objects: A dictionary of custom objects to look the name up in.
+      Generally, custom_objects is provided by the user.
+    printable_module_name: A human-readable string representing the type of the
+      object. Printed in case of exception.
+
+  Returns:
+    The deserialized object.
+  """
   if identifier is None:
     return None
 
@@ -351,25 +580,39 @@ def deserialize_keras_object(identifier,
     (cls, cls_config) = class_and_config_for_serialized_keras_object(
         config, module_objects, custom_objects, printable_module_name)
 
+    # If this object has already been loaded (i.e. it's shared between multiple
+    # objects), return the already-loaded object.
+    shared_object_id = config.get(SHARED_OBJECT_KEY)
+    shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
+    if shared_object is not None:
+      return shared_object
+
     if hasattr(cls, 'from_config'):
       arg_spec = tf_inspect.getfullargspec(cls.from_config)
       custom_objects = custom_objects or {}
 
       if 'custom_objects' in arg_spec.args:
-        return cls.from_config(
+        deserialized_obj = cls.from_config(
             cls_config,
             custom_objects=dict(
                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                 list(custom_objects.items())))
-      with CustomObjectScope(custom_objects):
-        return cls.from_config(cls_config)
+      else:
+        with CustomObjectScope(custom_objects):
+          deserialized_obj = cls.from_config(cls_config)
     else:
       # Then `cls` may be a function returning a class.
       # in this case by convention `config` holds
       # the kwargs of the function.
       custom_objects = custom_objects or {}
       with CustomObjectScope(custom_objects):
-        return cls(**cls_config)
+        deserialized_obj = cls(**cls_config)
+
+    # Add object to shared objects, in case we find it referenced again.
+    _shared_object_loading_scope().set(shared_object_id, deserialized_obj)
+
+    return deserialized_obj
   elif isinstance(identifier, six.string_types):
     object_name = identifier
     if custom_objects and object_name in custom_objects:
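And a corresponding loading-side sketch (the config dict and its ID value are illustrative, not taken from a real saved model):

import tensorflow as tf
from tensorflow.python.keras.utils import generic_utils

serialized = {'class_name': 'Dense',
              'config': {'units': 1},
              'shared_object_id': 0}  # illustrative ID

with generic_utils.SharedObjectLoadingScope():
  first = tf.keras.layers.deserialize(serialized)
  second = tf.keras.layers.deserialize(serialized)

# The second call finds ID 0 in the loading scope and returns the
# already-created layer rather than building a new one.
assert first is second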
tensorflow/python/keras/utils/generic_utils_test.py
@@ -23,6 +23,7 @@ from functools import partial
 import numpy as np
 
 from tensorflow.python import keras
+from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.platform import test
 
 
@@ -384,5 +385,63 @@ class SliceArraysTest(test.TestCase):
                         [None, None, None])
 
 
+# object() alone isn't compatible with WeakKeyDictionary, which we use to
+# track shared configs.
+class MaybeSharedObject(object):
+  pass
+
+
+class SharedObjectScopeTest(test.TestCase):
+
+  def test_shared_object_saving_scope_single_object_doesnt_export_id(self):
+    with generic_utils.SharedObjectSavingScope() as scope:
+      single_object = MaybeSharedObject()
+      self.assertIsNone(scope.get_config(single_object))
+      single_object_config = scope.create_config({}, single_object)
+      self.assertIsNotNone(single_object_config)
+      self.assertNotIn(generic_utils.SHARED_OBJECT_KEY,
+                       single_object_config)
+
+  def test_shared_object_saving_scope_shared_object_exports_id(self):
+    with generic_utils.SharedObjectSavingScope() as scope:
+      shared_object = MaybeSharedObject()
+      self.assertIsNone(scope.get_config(shared_object))
+      scope.create_config({}, shared_object)
+      first_object_config = scope.get_config(shared_object)
+      second_object_config = scope.get_config(shared_object)
+      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
+                    first_object_config)
+      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
+                    second_object_config)
+      self.assertIs(first_object_config, second_object_config)
+
+  def test_shared_object_loading_scope_noop(self):
+    # Test that, without a context manager scope, adding configs will do
+    # nothing.
+    obj_id = 1
+    obj = MaybeSharedObject()
+    generic_utils._shared_object_loading_scope().set(obj_id, obj)
+    self.assertIsNone(generic_utils._shared_object_loading_scope().get(obj_id))
+
+  def test_shared_object_loading_scope_returns_shared_obj(self):
+    obj_id = 1
+    obj = MaybeSharedObject()
+    with generic_utils.SharedObjectLoadingScope() as scope:
+      scope.set(obj_id, obj)
+      self.assertIs(scope.get(obj_id), obj)
+
+  def test_nested_shared_object_saving_scopes(self):
+    my_obj = MaybeSharedObject()
+    with generic_utils.SharedObjectSavingScope() as scope_1:
+      scope_1.create_config({}, my_obj)
+      with generic_utils.SharedObjectSavingScope() as scope_2:
+        # Nesting saving scopes should return the original scope and should
+        # not clear any objects we're tracking.
+        self.assertIs(scope_1, scope_2)
+        self.assertIsNotNone(scope_2.get_config(my_obj))
+      self.assertIsNotNone(scope_1.get_config(my_obj))
+    self.assertIsNone(generic_utils._shared_object_saving_scope())
+
+
 if __name__ == '__main__':
   test.main()