(rollforward) Add option to not save the traces when exporting to the SavedModel format.

The current tracing implementation has a limited scope of models and layers that can be traced. When users add a custom layer or model that is unsupported (e.g. with multiple tensor arguments), they'll come across an error that prevents them from saving the model entirely. With this option, those models can now be saved to the SavedModel format for serving or retraining.

PiperOrigin-RevId: 338295627
Change-Id: Ieea88ecaa1b8665df4ab45c96e882867e4308d88
This commit is contained in:
Katherine Wu 2020-10-21 10:50:41 -07:00 committed by TensorFlower Gardener
parent 7d57263720
commit 259ffa9ea6
26 changed files with 284 additions and 114 deletions

View File

@ -270,6 +270,12 @@
* For Keras model, the individual call of `Model.evaluate` uses no cached * For Keras model, the individual call of `Model.evaluate` uses no cached
data for evaluation, while `Model.fit` uses cached data when data for evaluation, while `Model.fit` uses cached data when
`validation_data` arg is provided for better performance. `validation_data` arg is provided for better performance.
* Added a `save_traces` argument to `model.save`/
`tf.keras.models.save_model` which determines whether the SavedModel
format stores the Keras model/layer call functions. The traced functions
allow Keras to revive custom models and layers without the original
class definition, but if this isn't required the tracing can be
disabled with the added option.
* `tf.function` / AutoGraph: * `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When * Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing True, the function may use type annotations to optimize the tracing

View File

@ -2104,9 +2104,10 @@ class Layer(base_layer.Layer):
# operations. # operations.
with tf_utils.maybe_init_scope(self): with tf_utils.maybe_init_scope(self):
self.build(input_shapes) self.build(input_shapes)
# We must set self.built since user defined build functions are not # We must set also ensure that the layer is marked as built, and the build
# constrained to set self.built. # shape is stored since user defined build functions may not be calling
self.built = True # `super.build()`
Layer.build(self, input_shapes)
# Optionally load weight values specified at layer instantiation. # Optionally load weight values specified at layer instantiation.
if self._initial_weights is not None: if self._initial_weights is not None:

View File

@ -1953,31 +1953,14 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
include_optimizer=True, include_optimizer=True,
save_format=None, save_format=None,
signatures=None, signatures=None,
options=None): options=None,
save_traces=True):
# pylint: disable=line-too-long
"""Saves the model to Tensorflow SavedModel or a single HDF5 file. """Saves the model to Tensorflow SavedModel or a single HDF5 file.
The savefile includes: Please see `tf.keras.models.save_model` or the
[Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/)
- The model architecture, allowing to re-instantiate the model. for details.
- The model weights.
- The state of the optimizer, allowing to resume training
exactly where you left off.
This allows you to save the entirety of the state of a model
in a single file.
Saved models can be re-instantiated via `keras.models.load_model`.
The model returned by `load_model` is a compiled model ready to be used
(unless the saved model was never compiled in the first place).
Models built with the Sequential and Functional API can be saved to both the
HDF5 and SavedModel formats. Subclassed models can only be saved with the
SavedModel format.
Note that the model weights may have different scoped names after being
loaded. Scoped names include the model/layer names, such as
`"dense_1/kernel:0"`. It is recommended that you use the layer properties to
access specific variables, e.g. `model.get_layer("dense_1").kernel`.
Arguments: Arguments:
filepath: String, PathLike, path to SavedModel or H5 file to save the filepath: String, PathLike, path to SavedModel or H5 file to save the
@ -1991,8 +1974,15 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
signatures: Signatures to save with the SavedModel. Applicable to the signatures: Signatures to save with the SavedModel. Applicable to the
'tf' format only. Please see the `signatures` argument in 'tf' format only. Please see the `signatures` argument in
`tf.saved_model.save` for details. `tf.saved_model.save` for details.
options: Optional `tf.saved_model.SaveOptions` object that specifies options: (only applies to SavedModel format)
options for saving to SavedModel. `tf.saved_model.SaveOptions` object that specifies options for
saving to SavedModel.
save_traces: (only applies to SavedModel format) When enabled, the
SavedModel will store the function traces for each layer. This
can be disabled, so that only the configs of each layer are stored.
Defaults to `True`. Disabling this will decrease serialization time
and reduce file size, but it requires that all custom layers/models
implement a `get_config()` method.
Example: Example:
@ -2007,8 +1997,9 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
model = load_model('my_model.h5') model = load_model('my_model.h5')
``` ```
""" """
# pylint: enable=line-too-long
save.save_model(self, filepath, overwrite, include_optimizer, save_format, save.save_model(self, filepath, overwrite, include_optimizer, save_format,
signatures, options) signatures, options, save_traces)
def save_weights(self, def save_weights(self,
filepath, filepath,

View File

@ -113,7 +113,6 @@ def run_with_all_saved_model_formats(
tf.test.main() tf.test.main()
``` ```
Args: Args:
test_or_class: test method or class to be annotated. If None, test_or_class: test method or class to be annotated. If None,
this method returns a decorator that can be applied to a test method or this method returns a decorator that can be applied to a test method or
@ -134,7 +133,7 @@ def run_with_all_saved_model_formats(
# Exclude h5 save format if H5py isn't available. # Exclude h5 save format if H5py isn't available.
if h5py is None: if h5py is None:
exclude_formats.append(['h5']) exclude_formats.append(['h5'])
saved_model_formats = ['h5', 'tf'] saved_model_formats = ['h5', 'tf', 'tf_no_traces']
params = [('_%s' % saved_format, saved_format) params = [('_%s' % saved_format, saved_format)
for saved_format in saved_model_formats for saved_format in saved_model_formats
if saved_format not in nest.flatten(exclude_formats)] if saved_format not in nest.flatten(exclude_formats)]
@ -150,6 +149,8 @@ def run_with_all_saved_model_formats(
_test_h5_saved_model_format(f, self, *args, **kwargs) _test_h5_saved_model_format(f, self, *args, **kwargs)
elif saved_format == 'tf': elif saved_format == 'tf':
_test_tf_saved_model_format(f, self, *args, **kwargs) _test_tf_saved_model_format(f, self, *args, **kwargs)
elif saved_format == 'tf_no_traces':
_test_tf_saved_model_format_no_traces(f, self, *args, **kwargs)
else: else:
raise ValueError('Unknown model type: %s' % (saved_format,)) raise ValueError('Unknown model type: %s' % (saved_format,))
return decorated return decorated
@ -167,6 +168,18 @@ def _test_tf_saved_model_format(f, test_or_class, *args, **kwargs):
f(test_or_class, *args, **kwargs) f(test_or_class, *args, **kwargs)
def _test_tf_saved_model_format_no_traces(f, test_or_class, *args, **kwargs):
with testing_utils.saved_model_format_scope('tf', save_traces=False):
f(test_or_class, *args, **kwargs)
def run_with_all_weight_formats(test_or_class=None, exclude_formats=None):
"""Runs all tests with the supported formats for saving weights."""
exclude_formats = exclude_formats or []
exclude_formats.append('tf_no_traces') # Only applies to saving models
return run_with_all_saved_model_formats(test_or_class, exclude_formats)
# TODO(kaftan): Possibly enable 'subclass_custom_build' when tests begin to pass # TODO(kaftan): Possibly enable 'subclass_custom_build' when tests begin to pass
# it. Or perhaps make 'subclass' always use a custom build method. # it. Or perhaps make 'subclass' always use a custom build method.
def run_with_all_model_types( def run_with_all_model_types(

View File

@ -58,7 +58,7 @@ except ImportError:
@combinations.generate(combinations.combine(mode=['graph', 'eager'])) @combinations.generate(combinations.combine(mode=['graph', 'eager']))
class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase): class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
@keras_parameterized.run_with_all_saved_model_formats @keras_parameterized.run_with_all_weight_formats
def test_weight_loading(self): def test_weight_loading(self):
temp_dir = self.get_temp_dir() temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir) self.addCleanup(shutil.rmtree, temp_dir)
@ -410,9 +410,14 @@ class TestWholeModelSaving(keras_parameterized.TestCase):
def test_save_and_load(self): def test_save_and_load(self):
saved_model_dir = self._save_model_dir() saved_model_dir = self._save_model_dir()
save_format = testing_utils.get_save_format() save_format = testing_utils.get_save_format()
save_kwargs = testing_utils.get_save_kwargs()
if save_format == 'h5' and testing_utils.get_model_type() == 'subclass': if ((save_format == 'h5' or not save_kwargs.get('save_traces', True)) and
return # HDF5 format currently does not allow saving classed models. testing_utils.get_model_type() == 'subclass'):
# HDF5 format currently does not allow saving subclassed models.
# When saving with `save_traces=False`, the subclassed model must have a
# get_config/from_config, which the autogenerated model does not have.
return
with self.cached_session(): with self.cached_session():
model = testing_utils.get_model_from_layers( model = testing_utils.get_model_from_layers(
@ -440,7 +445,9 @@ class TestWholeModelSaving(keras_parameterized.TestCase):
model.train_on_batch(x, y) model.train_on_batch(x, y)
out = model.predict(x) out = model.predict(x)
keras.models.save_model(model, saved_model_dir, save_format=save_format) keras.models.save_model(
model, saved_model_dir, save_format=save_format,
**save_kwargs)
loaded_model = keras.models.load_model(saved_model_dir) loaded_model = keras.models.load_model(saved_model_dir)
self._assert_same_weights_and_metrics(model, loaded_model) self._assert_same_weights_and_metrics(model, loaded_model)

View File

@ -52,9 +52,14 @@ def save_model(model,
include_optimizer=True, include_optimizer=True,
save_format=None, save_format=None,
signatures=None, signatures=None,
options=None): options=None,
save_traces=True):
# pylint: disable=line-too-long
"""Saves a model as a TensorFlow SavedModel or HDF5 file. """Saves a model as a TensorFlow SavedModel or HDF5 file.
See the [Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/)
for details.
Usage: Usage:
>>> model = tf.keras.Sequential([ >>> model = tf.keras.Sequential([
@ -65,28 +70,38 @@ def save_model(model,
>>> x = tf.random.uniform((10, 3)) >>> x = tf.random.uniform((10, 3))
>>> assert np.allclose(model.predict(x), loaded_model.predict(x)) >>> assert np.allclose(model.predict(x), loaded_model.predict(x))
The saved model contains: The SavedModel and HDF5 file contains:
- the model's configuration (topology) - the model's configuration (topology)
- the model's weights - the model's weights
- the model's optimizer's state (if any) - the model's optimizer's state (if any)
Thus the saved model can be reinstantiated in Thus models can be reinstantiated in the exact same state, without any of the
the exact same state, without any of the code code used for model definition or training.
used for model definition or training.
Note that the model weights may have different scoped names after being Note that the model weights may have different scoped names after being
loaded. Scoped names include the model/layer names, such as loaded. Scoped names include the model/layer names, such as
`"dense_1/kernel:0"`. It is recommended that you use the layer properties to `"dense_1/kernel:0"`. It is recommended that you use the layer properties to
access specific variables, e.g. `model.get_layer("dense_1").kernel`. access specific variables, e.g. `model.get_layer("dense_1").kernel`.
_SavedModel serialization_ __SavedModel serialization format__
The SavedModel serialization path uses `tf.saved_model.save` to save the model Keras SavedModel uses `tf.saved_model.save` to save the model and all
and all trackable objects attached to the model (e.g. layers and variables). trackable objects attached to the model (e.g. layers and variables). The model
`@tf.function`-decorated methods are also saved. Additional trackable objects config, weights, and optimizer are saved in the SavedModel. Additionally, for
and functions are added to the SavedModel to allow the model to be every Keras layer attached to the model, the SavedModel stores:
loaded back as a Keras Model object.
* the config and metadata -- e.g. name, dtype, trainable status
* traced call and loss functions, which are stored as TensorFlow subgraphs.
The traced functions allow the SavedModel format to save and load custom
layers without the original class definition.
You can choose to not save the traced functions by disabling the `save_traces`
option. This will decrease the time it takes to save the model and the
amount of disk space occupied by the output SavedModel. If you enable this
option, then you _must_ provide all custom class definitions when loading
the model. See the `custom_objects` argument in `tf.keras.models.load_model`.
Arguments: Arguments:
model: Keras model instance to be saved. model: Keras model instance to be saved.
@ -102,12 +117,19 @@ def save_model(model,
signatures: Signatures to save with the SavedModel. Applicable to the 'tf' signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
format only. Please see the `signatures` argument in format only. Please see the `signatures` argument in
`tf.saved_model.save` for details. `tf.saved_model.save` for details.
options: Optional `tf.saved_model.SaveOptions` object that specifies options: (only applies to SavedModel format) `tf.saved_model.SaveOptions`
options for saving to SavedModel. object that specifies options for saving to SavedModel.
save_traces: (only applies to SavedModel format) When enabled, the
SavedModel will store the function traces for each layer. This
can be disabled, so that only the configs of each layer are stored.
Defaults to `True`. Disabling this will decrease serialization time and
reduce file size, but it requires that all custom layers/models
implement a `get_config()` method.
Raises: Raises:
ImportError: If save format is hdf5, and h5py is not available. ImportError: If save format is hdf5, and h5py is not available.
""" """
# pylint: enable=line-too-long
from tensorflow.python.keras.engine import sequential # pylint: disable=g-import-not-at-top from tensorflow.python.keras.engine import sequential # pylint: disable=g-import-not-at-top
default_format = 'tf' if tf2.enabled() else 'h5' default_format = 'tf' if tf2.enabled() else 'h5'
@ -132,7 +154,7 @@ def save_model(model,
model, filepath, overwrite, include_optimizer) model, filepath, overwrite, include_optimizer)
else: else:
saved_model_save.save(model, filepath, overwrite, include_optimizer, saved_model_save.save(model, filepath, overwrite, include_optimizer,
signatures, options) signatures, options, save_traces)
@keras_export('keras.models.load_model') @keras_export('keras.models.load_model')

View File

@ -22,6 +22,7 @@ import abc
import six import six
from tensorflow.python.keras.saving.saved_model import json_utils from tensorflow.python.keras.saving.saved_model import json_utils
from tensorflow.python.keras.saving.saved_model import utils
from tensorflow.python.training.tracking import tracking from tensorflow.python.training.tracking import tracking
@ -71,6 +72,9 @@ class SavedModelSaver(object):
A dictionary mapping attribute names to trackable objects. The entire list A dictionary mapping attribute names to trackable objects. The entire list
of attributes are listed in the `saved_model._LayerAttributes` class. of attributes are listed in the `saved_model._LayerAttributes` class.
""" """
if not utils.should_save_traces():
return {}
return self.objects_to_serialize(serialization_cache) return self.objects_to_serialize(serialization_cache)
def list_functions_for_serialization(self, serialization_cache): def list_functions_for_serialization(self, serialization_cache):
@ -84,6 +88,9 @@ class SavedModelSaver(object):
A dictionary mapping attribute names to `Function` or A dictionary mapping attribute names to `Function` or
`ConcreteFunction`. `ConcreteFunction`.
""" """
if not utils.should_save_traces():
return {}
fns = self.functions_to_serialize(serialization_cache) fns = self.functions_to_serialize(serialization_cache)
# The parent AutoTrackable class saves all user-defined tf.functions, and # The parent AutoTrackable class saves all user-defined tf.functions, and

View File

@ -278,6 +278,16 @@ class KerasObjectLoader(object):
for name in PUBLIC_ATTRIBUTES: for name in PUBLIC_ATTRIBUTES:
delete_tracking(node, name) delete_tracking(node, name)
if isinstance(node, functional_lib.Functional):
# Delete the temporary layer dependencies, which were used to restore
# the checkpointed values. When the model is live, the user can delete
# or add layers to the model at any time, so these layer dependencies
# may be obsolete.
dependencies = list(node._self_unconditional_dependency_names) # pylint: disable=protected-access
for name in dependencies:
if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
delete_tracking(node, name)
def _add_children_recreated_from_config(self, obj, proto, node_id): def _add_children_recreated_from_config(self, obj, proto, node_id):
"""Recursively records objects recreated from config.""" """Recursively records objects recreated from config."""
# pylint: disable=protected-access # pylint: disable=protected-access
@ -302,7 +312,7 @@ class KerasObjectLoader(object):
# This is stored in the SavedModel as layer.keras_api.layer_metrics in # This is stored in the SavedModel as layer.keras_api.layer_metrics in
# SavedModels created after Tf 2.2. # SavedModels created after Tf 2.2.
metric_list_node_id = self._search_for_child_node( metric_list_node_id = self._search_for_child_node(
node_id, [constants.KERAS_ATTR, 'layer_metrics'], raise_error=False) node_id, [constants.KERAS_ATTR, 'layer_metrics'])
if metric_list_node_id is not None and hasattr(obj, '_metrics'): if metric_list_node_id is not None and hasattr(obj, '_metrics'):
obj_metrics = {m.name: m for m in obj._metrics} obj_metrics = {m.name: m for m in obj._metrics}
for reference in self._proto.nodes[metric_list_node_id].children: for reference in self._proto.nodes[metric_list_node_id].children:
@ -384,8 +394,10 @@ class KerasObjectLoader(object):
config = metadata.get('config') config = metadata.get('config')
if _is_graph_network(node) and generic_utils.validate_config(config): if _is_graph_network(node) and generic_utils.validate_config(config):
self.model_layer_dependencies[node_id] = ( child_nodes = self._get_child_layer_node_ids(node_id)
node, self._get_child_layer_node_ids(node_id, node.name)) self.model_layer_dependencies[node_id] = (node, child_nodes)
if not child_nodes:
self._models_to_reconstruct.append(node_id)
return node, setter return node, setter
# Detect whether this object can be revived from the config. If not, then # Detect whether this object can be revived from the config. If not, then
@ -448,9 +460,10 @@ class KerasObjectLoader(object):
# Record this model and its layers. This will later be used to reconstruct # Record this model and its layers. This will later be used to reconstruct
# the model. # the model.
layers = self._get_child_layer_node_ids(node_id, model.name) layers = self._get_child_layer_node_ids(node_id)
self.model_layer_dependencies[node_id] = (model, layers) self.model_layer_dependencies[node_id] = (model, layers)
if not layers:
self._models_to_reconstruct.append(node_id)
return model return model
def _revive_layer_from_config(self, metadata, node_id): def _revive_layer_from_config(self, metadata, node_id):
@ -621,8 +634,14 @@ class KerasObjectLoader(object):
"""Reconstructs the network structure.""" """Reconstructs the network structure."""
config = json_utils.decode( config = json_utils.decode(
self._proto.nodes[model_id].user_object.metadata)['config'] self._proto.nodes[model_id].user_object.metadata)['config']
if isinstance(model, models_lib.Sequential):
if not isinstance(layers[0], input_layer.InputLayer): # Set up model inputs
if model.inputs:
# Inputs may already be created if the model is instantiated in another
# object's __init__.
pass
elif isinstance(model, models_lib.Sequential):
if not layers or not isinstance(layers[0], input_layer.InputLayer):
if config['layers'][0]['class_name'] == 'InputLayer': if config['layers'][0]['class_name'] == 'InputLayer':
layers.insert(0, input_layer.InputLayer.from_config( layers.insert(0, input_layer.InputLayer.from_config(
config['layers'][0]['config'])) config['layers'][0]['config']))
@ -635,13 +654,13 @@ class KerasObjectLoader(object):
name=layers[0].name + '_input')) name=layers[0].name + '_input'))
model.__init__(layers, name=config['name']) model.__init__(layers, name=config['name'])
if not model.inputs: if not model.inputs:
first_layer = self._get_child_layer_node_ids(model_id, model.name)[0] first_layer = self._get_child_layer_node_ids(model_id)[0]
input_specs = self._infer_inputs(first_layer) input_specs = self._infer_inputs(first_layer)
input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True) input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True)
model._set_inputs(input_specs) # pylint: disable=protected-access model._set_inputs(input_specs) # pylint: disable=protected-access
if not model.built and not isinstance(input_specs, dict): if not model.built and not isinstance(input_specs, dict):
model.build(input_shapes) model.build(input_shapes)
else: else: # Reconstruct functional model
(inputs, outputs, (inputs, outputs,
created_layers) = functional_lib.reconstruct_from_config( created_layers) = functional_lib.reconstruct_from_config(
config, created_layers={layer.name: layer for layer in layers}) config, created_layers={layer.name: layer for layer in layers})
@ -654,15 +673,31 @@ class KerasObjectLoader(object):
# Unblock models that are dependent on this model. # Unblock models that are dependent on this model.
self._unblock_model_reconstruction(model_id, model) self._unblock_model_reconstruction(model_id, model)
def _get_child_layer_node_ids(self, node_id, name): def _get_child_layer_node_ids(self, node_id):
"""Returns the node ids of the children layers of a node.""" """Returns the node ids of each layer in a Sequential/Functional model."""
# Retrieve the node id of layer.keras_api.layers. # Sequential and Functional track layers with names following the format
layer_list = self._search_for_child_node( # "layer-N". Use this to generate the list of layers.
node_id, [constants.KERAS_ATTR, 'layers'], name) num_layers = 0
return [node.node_id for node in self._proto.nodes[layer_list].children] child_layers = {}
pattern = re.compile('layer-(\\d+)')
def _search_for_child_node( for child in self._proto.nodes[node_id].children:
self, parent_id, path_to_child, debugging_name=None, raise_error=True): m = pattern.match(child.local_name)
if m is None:
continue
layer_n = int(m.group(1))
num_layers = max(layer_n + 1, num_layers)
child_layers[layer_n] = child.node_id
ordered = []
for n in range(num_layers):
child = child_layers.get(n)
if child is None:
break
ordered.append(child)
return ordered
def _search_for_child_node(self, parent_id, path_to_child):
"""Returns node id of child node. """Returns node id of child node.
A helper method for traversing the object graph proto. A helper method for traversing the object graph proto.
@ -680,37 +715,23 @@ class KerasObjectLoader(object):
Args: Args:
parent_id: node id of parent node parent_id: node id of parent node
path_to_child: list of children names. path_to_child: list of children names.
debugging_name: the name to print out when raising an error.
raise_error: Whether to raise an error if the child isn't found.
Returns: Returns:
node_id of child, or None if child isn't found. node_id of child, or None if child isn't found.
Raises:
ValueError: if child isn't found and raise_error is True.
""" """
if not path_to_child: if not path_to_child:
return parent_id return parent_id
for child in self._proto.nodes[parent_id].children: for child in self._proto.nodes[parent_id].children:
if child.local_name == path_to_child[0]: if child.local_name == path_to_child[0]:
return self._search_for_child_node(child.node_id, path_to_child[1:], return self._search_for_child_node(child.node_id, path_to_child[1:])
debugging_name, raise_error) return None
if raise_error:
raise ValueError(
'Error when loading {}: could not find attribute {}.\n'
'Most likely this object was serialized incorrectly.'
.format(debugging_name or path_to_child[0], path_to_child[0]))
else:
return None
def _infer_inputs(self, layer_node_id, convert_to_shapes=False): def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
"""Infers input shape of layer from SavedModel functions.""" """Infers input shape of layer from SavedModel functions."""
coder = nested_structure_coder.StructureCoder() coder = nested_structure_coder.StructureCoder()
call_fn_id = self._search_for_child_node( call_fn_id = self._search_for_child_node(
layer_node_id, ['call_and_return_all_conditional_losses'], None, layer_node_id, ['call_and_return_all_conditional_losses'])
raise_error=False)
if call_fn_id is None: if call_fn_id is None:
return None return None
@ -797,9 +818,9 @@ def _unable_to_call_layer_due_to_serialization_issue(
""" """
raise ValueError( raise ValueError(
'Cannot call {} ({}), because the call function was not serialized to ' 'Cannot call custom layer {} of type {}, because the call function was '
'the SavedModel (due to lack information about the inputs). Please try ' 'not serialized to the SavedModel.'
'one of the following methods to fix the serialization:' 'Please try one of the following methods to fix this issue:'
'\n\n(1) Implement `get_config` and `from_config` in the layer/model ' '\n\n(1) Implement `get_config` and `from_config` in the layer/model '
'class, and pass the object to the `custom_objects` argument when ' 'class, and pass the object to the `custom_objects` argument when '
'loading the model. For more details, see: ' 'loading the model. For more details, see: '
@ -808,7 +829,7 @@ def _unable_to_call_layer_due_to_serialization_issue(
'and not `__call__`. The input shape and dtype will be automatically ' 'and not `__call__`. The input shape and dtype will be automatically '
'recorded when the object is called, and used when saving. To manually ' 'recorded when the object is called, and used when saving. To manually '
'specify the input shape/dtype, decorate the call function with ' 'specify the input shape/dtype, decorate the call function with '
'`@tf.function(input_signature=...)`.'.format(layer.name, layer)) '`@tf.function(input_signature=...)`.'.format(layer.name, type(layer)))
def _finalize_config_layers(layers): def _finalize_config_layers(layers):
@ -976,8 +997,11 @@ def _revive_setter(layer, name, value):
elif (isinstance(layer, functional_lib.Functional) and elif (isinstance(layer, functional_lib.Functional) and
re.match(r'^layer(_with_weights)?-[\d+]', name) is not None): re.match(r'^layer(_with_weights)?-[\d+]', name) is not None):
# Edges named "layer-n" or "layer_with_weights-n", which are tracked in # Edges named "layer-n" or "layer_with_weights-n", which are tracked in
# network._track_layers, should not be added as an attribute. # network._track_layers, should not be added as an attribute. They should
pass # be temporarily added as a dependency so that checkpointed values can be
# restored. These dependencies are manually deleted in
# KerasObjectLoader.del_tracking.
layer._track_trackable(value, name) # pylint: disable=protected-access
elif getattr(layer, name, None) is not None: elif getattr(layer, name, None) is not None:
# Don't overwrite already defined attributes. # Don't overwrite already defined attributes.
pass pass

View File

@ -22,6 +22,7 @@ from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.keras import backend as K from tensorflow.python.keras import backend as K
from tensorflow.python.keras.saving import saving_utils from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model import save_impl from tensorflow.python.keras.saving.saved_model import save_impl
from tensorflow.python.keras.saving.saved_model import utils
from tensorflow.python.keras.utils.generic_utils import LazyLoader from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.saved_model import save as save_lib from tensorflow.python.saved_model import save as save_lib
@ -38,7 +39,7 @@ training_lib = LazyLoader(
def save(model, filepath, overwrite, include_optimizer, signatures=None, def save(model, filepath, overwrite, include_optimizer, signatures=None,
options=None): options=None, save_traces=True):
"""Saves a model as a SavedModel to the filepath. """Saves a model as a SavedModel to the filepath.
Args: Args:
@ -49,8 +50,14 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
signatures: Signatures to save with the SavedModel. Applicable to the 'tf' signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
format only. Please see the `signatures` argument in `tf.saved_model.save` format only. Please see the `signatures` argument in `tf.saved_model.save`
for details. for details.
options: Optional `tf.saved_model.SaveOptions` object that specifies options: (only applies to SavedModel format) `tf.saved_model.SaveOptions`
options for saving to SavedModel. object that specifies options for saving to SavedModel.
save_traces: (only applies to SavedModel format) When enabled, the
SavedModel will store the function traces for each layer. This
can be disabled, so that only the configs of each layer are stored.
Defaults to `True`. Disabling this will decrease serialization time
and reduce file size, but it requires that all custom layers/models
implement a `get_config()` method.
Raises: Raises:
ValueError: if the model's inputs have not been defined. ValueError: if the model's inputs have not been defined.
@ -61,8 +68,9 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
if not proceed: if not proceed:
return return
if save_impl.should_skip_serialization(model): if save_traces:
saving_utils.raise_model_input_error(model) if save_impl.should_skip_serialization(model):
saving_utils.raise_model_input_error(model)
if not include_optimizer: if not include_optimizer:
orig_optimizer = model.optimizer orig_optimizer = model.optimizer
@ -77,7 +85,8 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
# the replica context is not available when calling `add_update()`, and thus # the replica context is not available when calling `add_update()`, and thus
# we use the default replica context here. # we use the default replica context here.
with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access
save_lib.save(model, filepath, signatures, options) with utils.keras_option_scope(save_traces):
save_lib.save(model, filepath, signatures, options)
if not include_optimizer: if not include_optimizer:
model.optimizer = orig_optimizer model.optimizer = orig_optimizer

View File

@ -114,7 +114,7 @@ class GlobalLayerThatShouldFailIfNotAdded(keras.layers.Layer):
@keras_parameterized.run_all_keras_modes @keras_parameterized.run_all_keras_modes
class TestModelSavingAndLoadingV2(keras_parameterized.TestCase): class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
def _save_model_dir(self, dirname='saved_model'): def _save_model_dir(self, dirname='saved_model'):
temp_dir = self.get_temp_dir() temp_dir = self.get_temp_dir()
@ -829,6 +829,14 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
self.evaluate(variables.variables_initializer(loaded.variables)) self.evaluate(variables.variables_initializer(loaded.variables))
self.assertAllClose(model.predict(f), loaded.predict(f)) self.assertAllClose(model.predict(f), loaded.predict(f))
class TestSavedModelFormat(test.TestCase):
def _save_model_dir(self, dirname='saved_model'):
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
return os.path.join(temp_dir, dirname)
def test_load_with_partially_failed_serialization(self): def test_load_with_partially_failed_serialization(self):
class BadCustomLayer(keras.layers.Layer): class BadCustomLayer(keras.layers.Layer):
@ -858,6 +866,48 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
with self.assertRaisesRegex(ValueError, 'call function was not serialized'): with self.assertRaisesRegex(ValueError, 'call function was not serialized'):
loaded.layer(inp) loaded.layer(inp)
def test_save_without_tracing(self):
class DoNotTrace(keras.layers.Layer):
def __init__(self):
super(DoNotTrace, self).__init__()
self.input_spec = keras.layers.InputSpec(shape=[None])
self.built = True
def call(self, inputs):
raise ValueError('I said do not trace')
def get_config(self):
return {}
root = keras.models.Sequential()
root.add(keras.layers.Input(shape=(3,)))
root.attached_layer = DoNotTrace()
saved_model_dir = self._save_model_dir()
# With the default settings, the call function is traced.
with self.assertRaisesRegex(ValueError, 'do not trace'):
root.save(saved_model_dir, save_format='tf')
# When saving the config only, the layer call function should not be not
# traced.
root.save(saved_model_dir, save_format='tf', save_traces=False)
loaded = tf_load.load(saved_model_dir)
self.assertTrue(hasattr(loaded, 'attached_layer'))
# This should raise an error when loaded without the custom object
loaded = keras_load.load(saved_model_dir)
with self.assertRaisesRegex(ValueError, 'Cannot call custom layer'):
loaded.attached_layer(constant_op.constant([1.]))
# Try loading with the custom objects
with generic_utils.CustomObjectScope({'DoNotTrace': DoNotTrace}):
loaded = keras_load.load(saved_model_dir)
with self.assertRaisesRegex(ValueError, 'I said do not trace'):
loaded.attached_layer(constant_op.constant([1.]))
class TestLayerCallTracing(test.TestCase, parameterized.TestCase): class TestLayerCallTracing(test.TestCase, parameterized.TestCase):

View File

@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import itertools import itertools
import threading
import types import types
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -25,6 +26,7 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_layer_utils from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.utils import control_flow_util from tensorflow.python.keras.utils import control_flow_util
from tensorflow.python.keras.utils import layer_utils from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.keras.utils import tf_contextlib
from tensorflow.python.keras.utils import tf_inspect from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils.generic_utils import LazyLoader from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_decorator
@ -245,3 +247,27 @@ def remove_training_arg(index, args, kwargs):
args.pop(index) args.pop(index)
else: else:
kwargs.pop('training', None) kwargs.pop('training', None)
class SaveOptionsContext(threading.local):
def __init__(self):
super(SaveOptionsContext, self).__init__()
self.save_traces = True
_save_options_context = SaveOptionsContext()
@tf_contextlib.contextmanager
def keras_option_scope(save_traces):
previous_value = _save_options_context.save_traces
try:
_save_options_context.save_traces = save_traces
yield
finally:
_save_options_context.save_traces = previous_value
def should_save_traces():
return _save_options_context.save_traces

View File

@ -304,6 +304,7 @@ _thread_local_data = threading.local()
_thread_local_data.model_type = None _thread_local_data.model_type = None
_thread_local_data.run_eagerly = None _thread_local_data.run_eagerly = None
_thread_local_data.saved_model_format = None _thread_local_data.saved_model_format = None
_thread_local_data.save_kwargs = None
@tf_contextlib.contextmanager @tf_contextlib.contextmanager
@ -383,7 +384,7 @@ def should_run_eagerly():
@tf_contextlib.contextmanager @tf_contextlib.contextmanager
def saved_model_format_scope(value): def saved_model_format_scope(value, **kwargs):
"""Provides a scope within which the savde model format to test is `value`. """Provides a scope within which the savde model format to test is `value`.
The saved model format gets restored to its original value upon exiting the The saved model format gets restored to its original value upon exiting the
@ -391,17 +392,21 @@ def saved_model_format_scope(value):
Arguments: Arguments:
value: saved model format value value: saved model format value
**kwargs: optional kwargs to pass to the save function.
Yields: Yields:
The provided value. The provided value.
""" """
previous_value = _thread_local_data.saved_model_format previous_format = _thread_local_data.saved_model_format
previous_kwargs = _thread_local_data.save_kwargs
try: try:
_thread_local_data.saved_model_format = value _thread_local_data.saved_model_format = value
yield value _thread_local_data.save_kwargs = kwargs
yield
finally: finally:
# Restore saved model format to initial value. # Restore saved model format to initial value.
_thread_local_data.saved_model_format = previous_value _thread_local_data.saved_model_format = previous_format
_thread_local_data.save_kwargs = previous_kwargs
def get_save_format(): def get_save_format():
@ -413,6 +418,15 @@ def get_save_format():
return _thread_local_data.saved_model_format return _thread_local_data.saved_model_format
def get_save_kwargs():
if _thread_local_data.save_kwargs is None:
raise ValueError(
'Cannot call `get_save_kwargs()` outside of a '
'`saved_model_format_scope()` or `run_with_all_saved_model_formats` '
'decorator.')
return _thread_local_data.save_kwargs or {}
def get_model_type(): def get_model_type():
"""Gets the model type that should be tested.""" """Gets the model type that should be tested."""
if _thread_local_data.model_type is None: if _thread_local_data.model_type is None:

View File

@ -310,7 +310,7 @@ tf_class {
} }
member_method { member_method {
name: "save" name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
} }
member_method { member_method {
name: "save_weights" name: "save_weights"

View File

@ -328,7 +328,7 @@ tf_class {
} }
member_method { member_method {
name: "save" name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
} }
member_method { member_method {
name: "save_weights" name: "save_weights"

View File

@ -311,7 +311,7 @@ tf_class {
} }
member_method { member_method {
name: "save" name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
} }
member_method { member_method {
name: "save_weights" name: "save_weights"

View File

@ -311,7 +311,7 @@ tf_class {
} }
member_method { member_method {
name: "save" name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
} }
member_method { member_method {
name: "save_weights" name: "save_weights"

View File

@ -310,7 +310,7 @@ tf_class {
} }
member_method { member_method {
name: "save" name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
} }
member_method { member_method {
name: "save_weights" name: "save_weights"

View File

@ -328,7 +328,7 @@ tf_class {
} }
member_method { member_method {
name: "save" name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
} }
member_method { member_method {
name: "save_weights" name: "save_weights"

View File

@ -30,6 +30,6 @@ tf_module {
} }
member_method { member_method {
name: "save_model" name: "save_model"
argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
} }
} }

View File

@ -310,7 +310,7 @@ tf_class {
} }
member_method { member_method {
name: "save" name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
} }
member_method { member_method {
name: "save_weights" name: "save_weights"

View File

@ -328,7 +328,7 @@ tf_class {
} }
member_method { member_method {
name: "save" name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
} }
member_method { member_method {
name: "save_weights" name: "save_weights"

View File

@ -311,7 +311,7 @@ tf_class {
} }
member_method { member_method {
name: "save" name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
} }
member_method { member_method {
name: "save_weights" name: "save_weights"

View File

@ -311,7 +311,7 @@ tf_class {
} }
member_method { member_method {
name: "save" name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
} }
member_method { member_method {
name: "save_weights" name: "save_weights"

View File

@ -310,7 +310,7 @@ tf_class {
} }
member_method { member_method {
name: "save" name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
} }
member_method { member_method {
name: "save_weights" name: "save_weights"

View File

@ -328,7 +328,7 @@ tf_class {
} }
member_method { member_method {
name: "save" name: "save"
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
} }
member_method { member_method {
name: "save_weights" name: "save_weights"

View File

@ -30,6 +30,6 @@ tf_module {
} }
member_method { member_method {
name: "save_model" name: "save_model"
argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], " argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
} }
} }