Add support for shared objects in Keras when (de)serializing.

Previously, objects shared between multiple layers would be duplicated when
saving and loading. This commit adds a unique ID to Keras objects when
serializing, allowing us to correctly create only a single instance of each
shared object when deserializing.

PiperOrigin-RevId: 354996891
Change-Id: Idfd055f430c7ea7e25c459ed7715a370d7a632c9
This commit is contained in:
  parent 21928d2d02
  commit c455215395
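To make the new behavior concrete, here is a minimal sketch of the round trip
this commit enables. It uses the TF-internal module paths that appear in this
diff (which may move between versions), and a `GlorotUniform` initializer as a
convenient stand-in for any object shared between layers:

from tensorflow.python.keras import initializers
from tensorflow.python.keras.utils import generic_utils

init = initializers.GlorotUniform()  # one object, referenced twice below

with generic_utils.SharedObjectSavingScope():
  cfg_first = initializers.serialize(init)   # first reference: plain config
  cfg_second = initializers.serialize(init)  # repeat: same config, ID attached

assert cfg_first is cfg_second                        # one SharedObjectConfig
assert generic_utils.SHARED_OBJECT_KEY in cfg_second  # 'shared_object_id'

with generic_utils.SharedObjectLoadingScope():
  first = initializers.deserialize(cfg_first)
  second = initializers.deserialize(cfg_second)

assert first is second  # a single instance is restored, not two copies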
@@ -671,6 +671,7 @@ class Functional(training_lib.Model):
     Raises:
       ValueError: In case of improperly formatted config dict.
     """
-    input_tensors, output_tensors, created_layers = reconstruct_from_config(
-        config, custom_objects)
-    model = cls(inputs=input_tensors, outputs=output_tensors,
+    with generic_utils.SharedObjectLoadingScope():
+      input_tensors, output_tensors, created_layers = reconstruct_from_config(
+          config, custom_objects)
+      model = cls(inputs=input_tensors, outputs=output_tensors,

@@ -1346,6 +1347,8 @@ def get_network_config(network, serialize_layer_fn=None):
         node_conversion_map[node_key] = kept_nodes
         kept_nodes += 1
   layer_configs = []
-  for layer in network.layers:  # From the earliest layers on.
-    filtered_inbound_nodes = []
-    for original_node_index, node in enumerate(layer._inbound_nodes):
+
+  with generic_utils.SharedObjectSavingScope():
+    for layer in network.layers:  # From the earliest layers on.
+      filtered_inbound_nodes = []
+      for original_node_index, node in enumerate(layer._inbound_nodes):
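Taken together, the two hunks above wrap the config round trip in matching
scopes, with no change for callers. A sketch (for any model whose layer
configs embed `serialize_keras_object` output, like the test case later in
this diff):

config = model.get_config()                 # under SharedObjectSavingScope
restored = keras.Model.from_config(         # under SharedObjectLoadingScope,
    config, custom_objects=custom_objects)  # so shared sub-objects stay shared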
@@ -393,6 +393,10 @@ def clone_model(model, input_tensors=None, clone_function=None):
   except that it creates new layers (and thus new weights) instead
   of sharing the weights of the existing layers.

+  `clone_model` will not preserve the uniqueness of shared objects within the
+  model (e.g. a single variable attached to two distinct layers will be
+  restored as two separate variables).
+
   Args:
       model: Instance of `Model`
           (could be a functional model or a Sequential model).

@@ -420,6 +424,7 @@ def clone_model(model, input_tensors=None, clone_function=None):
   Raises:
     ValueError: in case of invalid `model` argument value.
   """
-  if clone_function is None:
-    clone_function = _clone_layer
+  with generic_utils.DisableSharedObjectScope():
+    if clone_function is None:
+      clone_function = _clone_layer
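The consequence of disabling the scopes during cloning, sketched with
hypothetical attribute names (a model shaped like the shared-object test
later in this diff):

clone = keras.models.clone_model(model)
# Cloning rebuilds every layer, so a previously shared inner object is
# duplicated by design (`inner_layer` is an illustrative attribute name):
assert clone.layers[1].inner_layer is not clone.layers[2].inner_layer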
@@ -245,6 +245,28 @@ class TestModelCloning(keras_parameterized.TestCase):
     loss = model.train_on_batch(x, y)
     self.assertEqual(float(loss), 0.)

+  def test_clone_rnn(self):
+    # Test cloning a model with multiple cells in an RNN. This exercises a
+    # few "fancier" features such as the `Bidirectional` wrapper and
+    # `StackedRNNCells` under the hood.
+    inputs = keras.Input(shape=(3, 3))
+    cells = [
+        keras.layers.LSTMCell(
+            units=32,
+            enable_caching_device=True,
+            implementation=2,
+            activation='relu')]
+    rnn = keras.layers.RNN(cells, return_sequences=True)
+    outputs = keras.layers.Bidirectional(rnn)(inputs)
+    outputs = keras.layers.Dense(
+        12, activation='softmax', name='scores')(outputs)
+    model = keras.Model(inputs=inputs, outputs=outputs)
+    model.compile(
+        loss=keras.losses.CategoricalCrossentropy(),
+        optimizer=keras.optimizer_v2.rmsprop.RMSprop(lr=0.01),
+        metrics=['accuracy'])
+    keras.models.clone_model(model)
+
   def test_model_cloning_invalid_use_cases(self):
     seq_model = keras.models.Sequential()
     seq_model.add(keras.layers.Dense(4, input_shape=(4,)))
@@ -148,6 +148,7 @@ def save_model(model,
     hdf5_format.save_model_to_hdf5(
         model, filepath, overwrite, include_optimizer)
   else:
-    saved_model_save.save(model, filepath, overwrite, include_optimizer,
-                          signatures, options, save_traces)
+    with generic_utils.SharedObjectSavingScope():
+      saved_model_save.save(model, filepath, overwrite, include_optimizer,
+                            signatures, options, save_traces)

@@ -194,6 +195,7 @@ def load_model(filepath, custom_objects=None, compile=True, options=None):  # py
   ImportError: if loading from an hdf5 file and h5py is not available.
   IOError: In case of an invalid savefile.
   """
-  with generic_utils.CustomObjectScope(custom_objects or {}):
-    with load_context.load_context(options):
-      if (h5py is not None and
+  with generic_utils.SharedObjectLoadingScope():
+    with generic_utils.CustomObjectScope(custom_objects or {}):
+      with load_context.load_context(options):
+        if (h5py is not None and
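End users never open these scopes themselves; the public entry points now do
it, e.g. (the path is illustrative):

model.save('/tmp/shared_model')                          # saving scope active
restored = keras.models.load_model('/tmp/shared_model')  # loading scope active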
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import collections
 import os
 import shutil
 import sys

@@ -25,12 +26,14 @@ import tempfile

 from absl.testing import parameterized
 import numpy as np
+from six import string_types

 from tensorflow.python import keras
 from tensorflow.python import tf2
 from tensorflow.python.eager import context
 from tensorflow.python.feature_column import feature_column_lib
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.keras import combinations
@@ -859,6 +862,125 @@ class TestWholeModelSaving(keras_parameterized.TestCase):
     self.assertAllEqual(loaded_model.predict(args, batch_size=batch_size),
                         expected)

+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_shared_objects(self):
+    class OuterLayer(keras.layers.Layer):
+
+      def __init__(self, inner_layer):
+        super(OuterLayer, self).__init__()
+        self.inner_layer = inner_layer
+
+      def call(self, inputs):
+        return self.inner_layer(inputs)
+
+      def get_config(self):
+        return {
+            'inner_layer': generic_utils.serialize_keras_object(
+                self.inner_layer)
+        }
+
+      @classmethod
+      def from_config(cls, config):
+        return cls(generic_utils.deserialize_keras_object(
+            config['inner_layer']))
+
+    class InnerLayer(keras.layers.Layer):
+
+      def __init__(self):
+        super(InnerLayer, self).__init__()
+        self.v = self.add_weight(name='v', shape=[], dtype=dtypes.float32)
+
+      def call(self, inputs):
+        return self.v + inputs
+
+      @classmethod
+      def from_config(cls, config):
+        return cls()
+
+    # Create a model with 2 output layers that share the same inner layer.
+    inner_layer = InnerLayer()
+    outer_layer_1 = OuterLayer(inner_layer)
+    outer_layer_2 = OuterLayer(inner_layer)
+    input_ = keras.Input(shape=(1,))
+    model = keras.Model(
+        inputs=input_, outputs=[outer_layer_1(input_), outer_layer_2(input_)])
+
+    # Changes to the shared layer should affect both outputs.
+    model.layers[1].inner_layer.v.assign(5)
+    self.assertAllEqual(model(1), [6.0, 6.0])
+    model.layers[1].inner_layer.v.assign(3)
+    self.assertAllEqual(model(1), [4.0, 4.0])
+
+    # After loading, changes to the shared layer should still affect both
+    # outputs.
+    def _do_assertions(loaded):
+      loaded.layers[1].inner_layer.v.assign(5)
+      self.assertAllEqual(loaded(1), [6.0, 6.0])
+      loaded.layers[1].inner_layer.v.assign(3)
+      self.assertAllEqual(loaded(1), [4.0, 4.0])
+      loaded.layers[2].inner_layer.v.assign(5)
+      self.assertAllEqual(loaded(1), [6.0, 6.0])
+      loaded.layers[2].inner_layer.v.assign(3)
+      self.assertAllEqual(loaded(1), [4.0, 4.0])
+
+    # We'd like to make sure we only attach shared object IDs when strictly
+    # necessary, so we'll recursively traverse the generated config to count
+    # whether we have the exact number we expect.
+    def _get_all_keys_recursive(dict_or_iterable):
+      if isinstance(dict_or_iterable, dict):
+        for key in dict_or_iterable.keys():
+          yield key
+        for key in _get_all_keys_recursive(dict_or_iterable.values()):
+          yield key
+      elif isinstance(dict_or_iterable, string_types):
+        return
+      else:
+        try:
+          for item in dict_or_iterable:
+            for key in _get_all_keys_recursive(item):
+              yield key
+        # Not an iterable or dictionary
+        except TypeError:
+          return
+
+    with generic_utils.CustomObjectScope({
+        'OuterLayer': OuterLayer, 'InnerLayer': InnerLayer}):
+
+      # Test saving and loading to disk
+      save_format = testing_utils.get_save_format()
+      saved_model_dir = self._save_model_dir()
+      keras.models.save_model(model, saved_model_dir, save_format=save_format)
+      loaded = keras.models.load_model(saved_model_dir)
+      _do_assertions(loaded)
+
+      # Test recreating directly from config
+      config = model.get_config()
+      key_count = collections.Counter(_get_all_keys_recursive(config))
+      self.assertEqual(key_count[generic_utils.SHARED_OBJECT_KEY], 2)
+      loaded = keras.Model.from_config(config)
+      _do_assertions(loaded)
+
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_shared_objects_wrapper(self):
+    """Tests that shared layers wrapped with `Wrapper` restore correctly."""
+    input_ = keras.Input(shape=(1,))
+    unwrapped = keras.layers.Layer(name='unwrapped')
+    wrapped = keras.layers.Wrapper(unwrapped, name='wrapped')
+    model = keras.Model(inputs=input_,
+                        outputs=[unwrapped(input_), wrapped(input_)])
+
+    # Test recreating directly from config
+    config = model.get_config()
+    loaded = keras.Model.from_config(config)
+    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
+
+    # Test saving and loading to disk
+    save_format = testing_utils.get_save_format()
+    saved_model_dir = self._save_model_dir()
+    keras.models.save_model(model, saved_model_dir, save_format=save_format)
+    loaded = keras.models.load_model(saved_model_dir)
+    self.assertIs(loaded.layers[1], loaded.layers[2].layer)
+

 # Factory functions to create models that will be serialized inside a Network.
 def _make_graph_network(input_size, output_size):
@@ -46,7 +46,6 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
     # TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
     # the python config serialization has caught up.
     metadata = dict(
-        class_name=generic_utils.get_registered_name(type(self.obj)),
         name=self.obj.name,
         trainable=self.obj.trainable,
         expects_training_arg=self.obj._expects_training_arg,  # pylint: disable=protected-access

@@ -56,7 +55,7 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):
         must_restore_from_config=self.obj._must_restore_from_config,  # pylint: disable=protected-access
     )

-    metadata.update(get_config(self.obj))
+    metadata.update(get_serialized(self.obj))
     if self.obj.input_spec is not None:
       # Layer's input_spec has already been type-checked in the property setter.
       metadata['input_spec'] = nest.map_structure(

@@ -110,16 +109,12 @@ class LayerSavedModelSaver(base_serialization.SavedModelSaver):

 # TODO(kathywu): Move serialization utils (and related utils from
 # generic_utils.py) to a separate file.
-def get_config(obj):
+def get_serialized(obj):
   with generic_utils.skip_failed_serialization():
     # Store the config dictionary, which may be used when reviving the object.
     # When loading, the program will attempt to revive the object from config,
     # and if that fails, the object will be revived from the SavedModel.
-    config = generic_utils.serialize_keras_object(obj)['config']
-
-  if config is not None:
-    return {'config': config}
-  return {}
+    return generic_utils.serialize_keras_object(obj)


 class InputLayerSavedModelSaver(base_serialization.SavedModelSaver):
@@ -493,13 +493,15 @@ class KerasObjectLoader(object):
     # found.
     class_name = metadata.get('class_name')
     config = metadata.get('config')
+    shared_object_id = metadata.get('shared_object_id')
     must_restore_from_config = metadata.get('must_restore_from_config')
     if not generic_utils.validate_config(config):
       return None

     try:
       obj = layers_module.deserialize(
-          generic_utils.serialize_keras_class_and_config(class_name, config))
+          generic_utils.serialize_keras_class_and_config(
+              class_name, config, shared_object_id=shared_object_id))
     except ValueError:
       if must_restore_from_config:
         raise RuntimeError(
@@ -36,7 +36,7 @@ class MetricSavedModelSaver(layer_serialization.LayerSavedModelSaver):
         class_name=generic_utils.get_registered_name(type(self.obj)),
         name=self.obj.name,
         dtype=self.obj.dtype)
-    metadata.update(layer_serialization.get_config(self.obj))
+    metadata.update(layer_serialization.get_serialized(self.obj))
     if self.obj._build_input_shape is not None:  # pylint: disable=protected-access
       metadata['build_input_shape'] = self.obj._build_input_shape  # pylint: disable=protected-access
     return metadata
@@ -24,8 +24,10 @@ import marshal
 import os
 import re
 import sys
+import threading
 import time
 import types as python_types
+import weakref

 import numpy as np
 import six
@@ -110,9 +112,235 @@ def get_custom_objects():
   return _GLOBAL_CUSTOM_OBJECTS


-def serialize_keras_class_and_config(cls_name, cls_config):
+# Store a unique, per-object ID for shared objects.
+#
+# We store a unique ID for each object so that we may, at loading time,
+# re-create the network properly. Without this ID, we would have no way of
+# determining whether a config is a description of a new object that
+# should be created or is merely a reference to an already-created object.
+SHARED_OBJECT_KEY = 'shared_object_id'
+
+
+SHARED_OBJECT_DISABLED = threading.local()
+SHARED_OBJECT_LOADING = threading.local()
+SHARED_OBJECT_SAVING = threading.local()
+
+
+# Attributes on the threadlocal variable must be set per-thread, thus we
+# cannot initialize these globally. Instead, we have accessor functions with
+# default values.
+def _shared_object_disabled():
+  """Get whether shared object handling is disabled in a threadsafe manner."""
+  return getattr(SHARED_OBJECT_DISABLED, 'disabled', False)
+
+
+def _shared_object_loading_scope():
+  """Get the current shared object loading scope in a threadsafe manner."""
+  return getattr(SHARED_OBJECT_LOADING, 'scope', NoopLoadingScope())
+
+
+def _shared_object_saving_scope():
+  """Get the current shared object saving scope in a threadsafe manner."""
+  return getattr(SHARED_OBJECT_SAVING, 'scope', None)
+
+
+class DisableSharedObjectScope(object):
+  """A context manager for disabling handling of shared objects.
+
+  Disables shared object handling for both saving and loading.
+
+  Created primarily for use with `clone_model`, which does extra surgery that
+  is incompatible with shared objects.
+  """
+
+  def __enter__(self):
+    SHARED_OBJECT_DISABLED.disabled = True
+    self._orig_loading_scope = _shared_object_loading_scope()
+    self._orig_saving_scope = _shared_object_saving_scope()
+
+  def __exit__(self, *args, **kwargs):
+    SHARED_OBJECT_DISABLED.disabled = False
+    SHARED_OBJECT_LOADING.scope = self._orig_loading_scope
+    SHARED_OBJECT_SAVING.scope = self._orig_saving_scope
+
+
+class NoopLoadingScope(object):
+  """The default shared object loading scope. It does nothing.
+
+  Created to simplify serialization code that doesn't care about shared objects
+  (e.g. when serializing a single object).
+  """
+
+  def get(self, unused_object_id):
+    return None
+
+  def set(self, object_id, obj):
+    pass
+
+
+class SharedObjectLoadingScope(object):
+  """A context manager for keeping track of loaded objects.
+
+  During the deserialization process, we may come across objects that are
+  shared across multiple layers. In order to accurately restore the network
+  structure to its original state, `SharedObjectLoadingScope` allows us to
+  re-use shared objects rather than cloning them.
+  """
+
+  def __enter__(self):
+    if _shared_object_disabled():
+      return NoopLoadingScope()
+
+    global SHARED_OBJECT_LOADING
+    SHARED_OBJECT_LOADING.scope = self
+    self._obj_ids_to_obj = {}
+    return self
+
+  def get(self, object_id):
+    """Given a shared object ID, returns a previously instantiated object.
+
+    Args:
+      object_id: shared object ID to use when attempting to find already-loaded
+        object.
+
+    Returns:
+      The object, if we've seen this ID before. Else, `None`.
+    """
+    # Explicitly check for `None` internally to make external calling code a
+    # bit cleaner.
+    if object_id is None:
+      return
+    return self._obj_ids_to_obj.get(object_id)
+
+  def set(self, object_id, obj):
+    """Stores an instantiated object for future lookup and sharing."""
+    if object_id is None:
+      return
+    self._obj_ids_to_obj[object_id] = obj
+
+  def __exit__(self, *args, **kwargs):
+    global SHARED_OBJECT_LOADING
+    SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
+
+
+class SharedObjectConfig(dict):
+  """A configuration container that keeps track of references.
+
+  `SharedObjectConfig` will automatically attach a shared object ID to any
+  configs which are referenced more than once, allowing for proper shared
+  object reconstruction at load time.
+
+  In most cases, it would be more proper to subclass something like
+  `collections.UserDict` or `collections.Mapping` rather than `dict` directly.
+  Unfortunately, python's json encoder does not support `Mapping`s. This is
+  important functionality to retain, since we are dealing with serialization.
+
+  We should be safe to subclass `dict` here, since we aren't actually
+  overriding any core methods, only augmenting with a new one for reference
+  counting.
+  """
+
+  def __init__(self, base_config, object_id, **kwargs):
+    self.ref_count = 1
+    self.object_id = object_id
+    super(SharedObjectConfig, self).__init__(base_config, **kwargs)
+
+  def increment_ref_count(self):
+    # As soon as we've seen the object more than once, we want to attach the
+    # shared object ID. This allows us to only attach the shared object ID
+    # when it's strictly necessary, making backwards compatibility breakage
+    # less likely.
+    if self.ref_count == 1:
+      self[SHARED_OBJECT_KEY] = self.object_id
+    self.ref_count += 1
+
+
+class SharedObjectSavingScope(object):
+  """Keeps track of shared object configs when serializing."""
+
+  def __enter__(self):
+    if _shared_object_disabled():
+      return None
+
+    global SHARED_OBJECT_SAVING
+
+    # Serialization can happen at a number of layers for a number of reasons.
+    # We may end up with a case where we're opening a saving scope within
+    # another saving scope. In that case, we'd like to use the outermost scope
+    # available and ignore inner scopes, since there is not (yet) a reasonable
+    # use case for having these nested and distinct.
+    if _shared_object_saving_scope() is not None:
+      self._passthrough = True
+      return _shared_object_saving_scope()
+    else:
+      self._passthrough = False
+
+    SHARED_OBJECT_SAVING.scope = self
+    self._shared_objects_config = weakref.WeakKeyDictionary()
+    self._next_id = 0
+    return self
+
+  def get_config(self, obj):
+    """Gets a `SharedObjectConfig` if one has already been seen for `obj`.
+
+    Args:
+      obj: The object for which to retrieve the `SharedObjectConfig`.
+
+    Returns:
+      The SharedObjectConfig for a given object, if already seen. Else,
+        `None`.
+    """
+    try:
+      shared_object_config = self._shared_objects_config[obj]
+    except (TypeError, KeyError):
+      # If the object is unhashable (e.g. a subclass of `AbstractBaseClass`
+      # that has not overridden `__hash__`), a `TypeError` will be thrown.
+      # We'll just continue on without shared object support.
+      return None
+    shared_object_config.increment_ref_count()
+    return shared_object_config
+
+  def create_config(self, base_config, obj):
+    """Create a new SharedObjectConfig for a given object."""
+    shared_object_config = SharedObjectConfig(base_config, self._next_id)
+    self._next_id += 1
+    try:
+      self._shared_objects_config[obj] = shared_object_config
+    except TypeError:
+      # If the object is unhashable (e.g. a subclass of `AbstractBaseClass`
+      # that has not overridden `__hash__`), a `TypeError` will be thrown.
+      # We'll just continue on without shared object support.
+      pass
+    return shared_object_config
+
+  def __exit__(self, *args, **kwargs):
+    if not getattr(self, '_passthrough', False):
+      global SHARED_OBJECT_SAVING
+      SHARED_OBJECT_SAVING.scope = None
+
+
+def serialize_keras_class_and_config(
+    cls_name, cls_config, obj=None, shared_object_id=None):
   """Returns the serialization of the class with the given config."""
-  return {'class_name': cls_name, 'config': cls_config}
+  base_config = {'class_name': cls_name, 'config': cls_config}
+
+  # We call `serialize_keras_class_and_config` for some branches of the load
+  # path. In that case, we may already have a shared object ID we'd like to
+  # retain.
+  if shared_object_id is not None:
+    base_config[SHARED_OBJECT_KEY] = shared_object_id
+
+  # If we have an active `SharedObjectSavingScope`, check whether we've
+  # already serialized this config. If so, just use that config. This will
+  # store an extra ID field in the config, allowing us to re-create the
+  # shared object relationship at load time.
+  if _shared_object_saving_scope() is not None and obj is not None:
+    shared_object_config = _shared_object_saving_scope().get_config(obj)
+    if shared_object_config is None:
+      return _shared_object_saving_scope().create_config(base_config, obj)
+    return shared_object_config
+
+  return base_config


 @keras_export('keras.utils.register_keras_serializable')
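Concretely, the first reference to an object serializes as before; from the
second reference on, the single shared `SharedObjectConfig` carries the ID,
so every occurrence of it in the nested output does too. An illustrative
(not verbatim) entry, with `InnerLayer` standing in for any shared object:

# Both occurrences in the nested config are the same SharedObjectConfig,
# serialized twice:
{'class_name': 'InnerLayer',   # hypothetical shared object
 'config': {},
 'shared_object_id': 0}        # SHARED_OBJECT_KEY, attached on 2nd reference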
@@ -234,7 +462,19 @@ def get_registered_object(name, custom_objects=None, module_objects=None):

 @keras_export('keras.utils.serialize_keras_object')
 def serialize_keras_object(instance):
-  """Serialize a Keras object into a JSON-compatible representation."""
+  """Serialize a Keras object into a JSON-compatible representation.
+
+  Calls to `serialize_keras_object` while underneath the
+  `SharedObjectSavingScope` context manager will cause any objects re-used
+  across multiple layers to be saved with a special shared object ID. This
+  allows the network to be re-created properly during deserialization.
+
+  Args:
+    instance: The object to serialize.
+
+  Returns:
+    A dict-like, JSON-compatible representation of the object's config.
+  """
   _, instance = tf_decorator.unwrap(instance)
   if instance is None:
     return None

@@ -265,7 +505,8 @@ def serialize_keras_object(instance):
       serialization_config[key] = item

     name = get_registered_name(instance.__class__)
-    return serialize_keras_class_and_config(name, serialization_config)
+    return serialize_keras_class_and_config(
+        name, serialization_config, instance)
   if hasattr(instance, '__name__'):
     return get_registered_name(instance)
   raise ValueError('Cannot serialize', instance)

@@ -286,8 +527,9 @@ def class_and_config_for_serialized_keras_object(
     custom_objects=None,
     printable_module_name='object'):
   """Returns the class name and config for a serialized keras object."""
-  if (not isinstance(config, dict) or 'class_name' not in config or
-      'config' not in config):
+  if (not isinstance(config, dict)
+      or 'class_name' not in config
+      or 'config' not in config):
     raise ValueError('Improper config format: ' + str(config))

   class_name = config['class_name']

@@ -341,7 +583,24 @@ def deserialize_keras_object(identifier,
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
-  """Turns the serialized form of a Keras object back into an actual object."""
+  """Turns the serialized form of a Keras object back into an actual object.
+
+  Calls to `deserialize_keras_object` while underneath the
+  `SharedObjectLoadingScope` context manager will cause any already-seen
+  shared objects to be returned as-is rather than creating a new object.
+
+  Args:
+    identifier: the serialized form of the object.
+    module_objects: A dictionary of custom objects to look the name up in.
+      Generally, module_objects is provided by midlevel library implementers.
+    custom_objects: A dictionary of custom objects to look the name up in.
+      Generally, custom_objects is provided by the user.
+    printable_module_name: A human-readable string representing the type of
+      the object. Printed in case of exception.
+
+  Returns:
+    The deserialized object.
+  """
   if identifier is None:
     return None

@@ -351,25 +610,39 @@ def deserialize_keras_object(identifier,
     (cls, cls_config) = class_and_config_for_serialized_keras_object(
         config, module_objects, custom_objects, printable_module_name)

+    # If this object has already been loaded (i.e. it's shared between
+    # multiple objects), return the already-loaded object.
+    shared_object_id = config.get(SHARED_OBJECT_KEY)
+    shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
+    if shared_object is not None:
+      return shared_object
+
     if hasattr(cls, 'from_config'):
       arg_spec = tf_inspect.getfullargspec(cls.from_config)
       custom_objects = custom_objects or {}

       if 'custom_objects' in arg_spec.args:
-        return cls.from_config(
+        deserialized_obj = cls.from_config(
             cls_config,
             custom_objects=dict(
                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                 list(custom_objects.items())))
-      with CustomObjectScope(custom_objects):
-        return cls.from_config(cls_config)
+      else:
+        with CustomObjectScope(custom_objects):
+          deserialized_obj = cls.from_config(cls_config)
     else:
       # Then `cls` may be a function returning a class.
       # in this case by convention `config` holds
       # the kwargs of the function.
       custom_objects = custom_objects or {}
       with CustomObjectScope(custom_objects):
-        return cls(**cls_config)
+        deserialized_obj = cls(**cls_config)
+
+    # Add object to shared objects, in case we find it referenced again.
+    _shared_object_loading_scope().set(shared_object_id, deserialized_obj)
+
+    return deserialized_obj
+
   elif isinstance(identifier, six.string_types):
     object_name = identifier
     if custom_objects and object_name in custom_objects:
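The loading side of the hunk above, reduced to its contract (a sketch;
`build_from_config` is a hypothetical stand-in for the
`cls.from_config`/`cls(**cls_config)` branches):

with generic_utils.SharedObjectLoadingScope() as scope:
  obj = scope.get(shared_object_id)   # None the first time an ID is seen
  if obj is None:
    obj = build_from_config(config)   # hypothetical builder
    scope.set(shared_object_id, obj)  # later references get this instance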
@@ -23,6 +23,7 @@ from functools import partial
 import numpy as np

 from tensorflow.python import keras
+from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.platform import test


@@ -384,5 +385,63 @@ class SliceArraysTest(test.TestCase):
                         [None, None, None])


+# object() alone isn't compatible with WeakKeyDictionary, which we use to
+# track shared configs.
+class MaybeSharedObject(object):
+  pass
+
+
+class SharedObjectScopeTest(test.TestCase):
+
+  def test_shared_object_saving_scope_single_object_doesnt_export_id(self):
+    with generic_utils.SharedObjectSavingScope() as scope:
+      single_object = MaybeSharedObject()
+      self.assertIsNone(scope.get_config(single_object))
+      single_object_config = scope.create_config({}, single_object)
+      self.assertIsNotNone(single_object_config)
+      self.assertNotIn(generic_utils.SHARED_OBJECT_KEY,
+                       single_object_config)
+
+  def test_shared_object_saving_scope_shared_object_exports_id(self):
+    with generic_utils.SharedObjectSavingScope() as scope:
+      shared_object = MaybeSharedObject()
+      self.assertIsNone(scope.get_config(shared_object))
+      scope.create_config({}, shared_object)
+      first_object_config = scope.get_config(shared_object)
+      second_object_config = scope.get_config(shared_object)
+      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
+                    first_object_config)
+      self.assertIn(generic_utils.SHARED_OBJECT_KEY,
+                    second_object_config)
+      self.assertIs(first_object_config, second_object_config)
+
+  def test_shared_object_loading_scope_noop(self):
+    # Test that, without a context manager scope, adding configs will do
+    # nothing.
+    obj_id = 1
+    obj = MaybeSharedObject()
+    generic_utils._shared_object_loading_scope().set(obj_id, obj)
+    self.assertIsNone(generic_utils._shared_object_loading_scope().get(obj_id))
+
+  def test_shared_object_loading_scope_returns_shared_obj(self):
+    obj_id = 1
+    obj = MaybeSharedObject()
+    with generic_utils.SharedObjectLoadingScope() as scope:
+      scope.set(obj_id, obj)
+      self.assertIs(scope.get(obj_id), obj)
+
+  def test_nested_shared_object_saving_scopes(self):
+    my_obj = MaybeSharedObject()
+    with generic_utils.SharedObjectSavingScope() as scope_1:
+      scope_1.create_config({}, my_obj)
+      with generic_utils.SharedObjectSavingScope() as scope_2:
+        # Nesting saving scopes should return the original scope and should
+        # not clear any objects we're tracking.
+        self.assertIs(scope_1, scope_2)
+        self.assertIsNotNone(scope_2.get_config(my_obj))
+      self.assertIsNotNone(scope_1.get_config(my_obj))
+    self.assertIsNone(generic_utils._shared_object_saving_scope())
+
+
 if __name__ == '__main__':
   test.main()