(rollforward) Add option to not save the traces when exporting to the SavedModel format.
The current tracing implementation has a limited scope of models and layers that can be traced. When users add a custom layer or model that is unsupported (e.g. with multiple tensor arguments), they'll come across an error that prevents them from saving the model entirely. With this option, those models can now be saved to the SavedModel format for serving or retraining. PiperOrigin-RevId: 338295627 Change-Id: Ieea88ecaa1b8665df4ab45c96e882867e4308d88
This commit is contained in:
parent
7d57263720
commit
259ffa9ea6
@ -270,6 +270,12 @@
|
||||
* For Keras model, the individual call of `Model.evaluate` uses no cached
|
||||
data for evaluation, while `Model.fit` uses cached data when
|
||||
`validation_data` arg is provided for better performance.
|
||||
* Added a `save_traces` argument to `model.save`/
|
||||
`tf.keras.models.save_model` which determines whether the SavedModel
|
||||
format stores the Keras model/layer call functions. The traced functions
|
||||
allow Keras to revive custom models and layers without the original
|
||||
class definition, but if this isn't required the tracing can be
|
||||
disabled with the added option.
|
||||
* `tf.function` / AutoGraph:
|
||||
* Added `experimental_follow_type_hints` argument for `tf.function`. When
|
||||
True, the function may use type annotations to optimize the tracing
|
||||
|
@ -2104,9 +2104,10 @@ class Layer(base_layer.Layer):
|
||||
# operations.
|
||||
with tf_utils.maybe_init_scope(self):
|
||||
self.build(input_shapes)
|
||||
# We must set self.built since user defined build functions are not
|
||||
# constrained to set self.built.
|
||||
self.built = True
|
||||
# We must set also ensure that the layer is marked as built, and the build
|
||||
# shape is stored since user defined build functions may not be calling
|
||||
# `super.build()`
|
||||
Layer.build(self, input_shapes)
|
||||
|
||||
# Optionally load weight values specified at layer instantiation.
|
||||
if self._initial_weights is not None:
|
||||
|
@ -1953,31 +1953,14 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
include_optimizer=True,
|
||||
save_format=None,
|
||||
signatures=None,
|
||||
options=None):
|
||||
options=None,
|
||||
save_traces=True):
|
||||
# pylint: disable=line-too-long
|
||||
"""Saves the model to Tensorflow SavedModel or a single HDF5 file.
|
||||
|
||||
The savefile includes:
|
||||
|
||||
- The model architecture, allowing to re-instantiate the model.
|
||||
- The model weights.
|
||||
- The state of the optimizer, allowing to resume training
|
||||
exactly where you left off.
|
||||
|
||||
This allows you to save the entirety of the state of a model
|
||||
in a single file.
|
||||
|
||||
Saved models can be re-instantiated via `keras.models.load_model`.
|
||||
The model returned by `load_model` is a compiled model ready to be used
|
||||
(unless the saved model was never compiled in the first place).
|
||||
|
||||
Models built with the Sequential and Functional API can be saved to both the
|
||||
HDF5 and SavedModel formats. Subclassed models can only be saved with the
|
||||
SavedModel format.
|
||||
|
||||
Note that the model weights may have different scoped names after being
|
||||
loaded. Scoped names include the model/layer names, such as
|
||||
`"dense_1/kernel:0"`. It is recommended that you use the layer properties to
|
||||
access specific variables, e.g. `model.get_layer("dense_1").kernel`.
|
||||
Please see `tf.keras.models.save_model` or the
|
||||
[Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/)
|
||||
for details.
|
||||
|
||||
Arguments:
|
||||
filepath: String, PathLike, path to SavedModel or H5 file to save the
|
||||
@ -1991,8 +1974,15 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
signatures: Signatures to save with the SavedModel. Applicable to the
|
||||
'tf' format only. Please see the `signatures` argument in
|
||||
`tf.saved_model.save` for details.
|
||||
options: Optional `tf.saved_model.SaveOptions` object that specifies
|
||||
options for saving to SavedModel.
|
||||
options: (only applies to SavedModel format)
|
||||
`tf.saved_model.SaveOptions` object that specifies options for
|
||||
saving to SavedModel.
|
||||
save_traces: (only applies to SavedModel format) When enabled, the
|
||||
SavedModel will store the function traces for each layer. This
|
||||
can be disabled, so that only the configs of each layer are stored.
|
||||
Defaults to `True`. Disabling this will decrease serialization time
|
||||
and reduce file size, but it requires that all custom layers/models
|
||||
implement a `get_config()` method.
|
||||
|
||||
Example:
|
||||
|
||||
@ -2007,8 +1997,9 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
|
||||
model = load_model('my_model.h5')
|
||||
```
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
save.save_model(self, filepath, overwrite, include_optimizer, save_format,
|
||||
signatures, options)
|
||||
signatures, options, save_traces)
|
||||
|
||||
def save_weights(self,
|
||||
filepath,
|
||||
|
@ -113,7 +113,6 @@ def run_with_all_saved_model_formats(
|
||||
tf.test.main()
|
||||
```
|
||||
|
||||
|
||||
Args:
|
||||
test_or_class: test method or class to be annotated. If None,
|
||||
this method returns a decorator that can be applied to a test method or
|
||||
@ -134,7 +133,7 @@ def run_with_all_saved_model_formats(
|
||||
# Exclude h5 save format if H5py isn't available.
|
||||
if h5py is None:
|
||||
exclude_formats.append(['h5'])
|
||||
saved_model_formats = ['h5', 'tf']
|
||||
saved_model_formats = ['h5', 'tf', 'tf_no_traces']
|
||||
params = [('_%s' % saved_format, saved_format)
|
||||
for saved_format in saved_model_formats
|
||||
if saved_format not in nest.flatten(exclude_formats)]
|
||||
@ -150,6 +149,8 @@ def run_with_all_saved_model_formats(
|
||||
_test_h5_saved_model_format(f, self, *args, **kwargs)
|
||||
elif saved_format == 'tf':
|
||||
_test_tf_saved_model_format(f, self, *args, **kwargs)
|
||||
elif saved_format == 'tf_no_traces':
|
||||
_test_tf_saved_model_format_no_traces(f, self, *args, **kwargs)
|
||||
else:
|
||||
raise ValueError('Unknown model type: %s' % (saved_format,))
|
||||
return decorated
|
||||
@ -167,6 +168,18 @@ def _test_tf_saved_model_format(f, test_or_class, *args, **kwargs):
|
||||
f(test_or_class, *args, **kwargs)
|
||||
|
||||
|
||||
def _test_tf_saved_model_format_no_traces(f, test_or_class, *args, **kwargs):
|
||||
with testing_utils.saved_model_format_scope('tf', save_traces=False):
|
||||
f(test_or_class, *args, **kwargs)
|
||||
|
||||
|
||||
def run_with_all_weight_formats(test_or_class=None, exclude_formats=None):
|
||||
"""Runs all tests with the supported formats for saving weights."""
|
||||
exclude_formats = exclude_formats or []
|
||||
exclude_formats.append('tf_no_traces') # Only applies to saving models
|
||||
return run_with_all_saved_model_formats(test_or_class, exclude_formats)
|
||||
|
||||
|
||||
# TODO(kaftan): Possibly enable 'subclass_custom_build' when tests begin to pass
|
||||
# it. Or perhaps make 'subclass' always use a custom build method.
|
||||
def run_with_all_model_types(
|
||||
|
@ -58,7 +58,7 @@ except ImportError:
|
||||
@combinations.generate(combinations.combine(mode=['graph', 'eager']))
|
||||
class TestWeightSavingAndLoading(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@keras_parameterized.run_with_all_saved_model_formats
|
||||
@keras_parameterized.run_with_all_weight_formats
|
||||
def test_weight_loading(self):
|
||||
temp_dir = self.get_temp_dir()
|
||||
self.addCleanup(shutil.rmtree, temp_dir)
|
||||
@ -410,9 +410,14 @@ class TestWholeModelSaving(keras_parameterized.TestCase):
|
||||
def test_save_and_load(self):
|
||||
saved_model_dir = self._save_model_dir()
|
||||
save_format = testing_utils.get_save_format()
|
||||
save_kwargs = testing_utils.get_save_kwargs()
|
||||
|
||||
if save_format == 'h5' and testing_utils.get_model_type() == 'subclass':
|
||||
return # HDF5 format currently does not allow saving classed models.
|
||||
if ((save_format == 'h5' or not save_kwargs.get('save_traces', True)) and
|
||||
testing_utils.get_model_type() == 'subclass'):
|
||||
# HDF5 format currently does not allow saving subclassed models.
|
||||
# When saving with `save_traces=False`, the subclassed model must have a
|
||||
# get_config/from_config, which the autogenerated model does not have.
|
||||
return
|
||||
|
||||
with self.cached_session():
|
||||
model = testing_utils.get_model_from_layers(
|
||||
@ -440,7 +445,9 @@ class TestWholeModelSaving(keras_parameterized.TestCase):
|
||||
model.train_on_batch(x, y)
|
||||
|
||||
out = model.predict(x)
|
||||
keras.models.save_model(model, saved_model_dir, save_format=save_format)
|
||||
keras.models.save_model(
|
||||
model, saved_model_dir, save_format=save_format,
|
||||
**save_kwargs)
|
||||
|
||||
loaded_model = keras.models.load_model(saved_model_dir)
|
||||
self._assert_same_weights_and_metrics(model, loaded_model)
|
||||
|
@ -52,9 +52,14 @@ def save_model(model,
|
||||
include_optimizer=True,
|
||||
save_format=None,
|
||||
signatures=None,
|
||||
options=None):
|
||||
options=None,
|
||||
save_traces=True):
|
||||
# pylint: disable=line-too-long
|
||||
"""Saves a model as a TensorFlow SavedModel or HDF5 file.
|
||||
|
||||
See the [Serialization and Saving guide](https://keras.io/guides/serialization_and_saving/)
|
||||
for details.
|
||||
|
||||
Usage:
|
||||
|
||||
>>> model = tf.keras.Sequential([
|
||||
@ -65,28 +70,38 @@ def save_model(model,
|
||||
>>> x = tf.random.uniform((10, 3))
|
||||
>>> assert np.allclose(model.predict(x), loaded_model.predict(x))
|
||||
|
||||
The saved model contains:
|
||||
The SavedModel and HDF5 file contains:
|
||||
|
||||
- the model's configuration (topology)
|
||||
- the model's weights
|
||||
- the model's optimizer's state (if any)
|
||||
|
||||
Thus the saved model can be reinstantiated in
|
||||
the exact same state, without any of the code
|
||||
used for model definition or training.
|
||||
Thus models can be reinstantiated in the exact same state, without any of the
|
||||
code used for model definition or training.
|
||||
|
||||
Note that the model weights may have different scoped names after being
|
||||
loaded. Scoped names include the model/layer names, such as
|
||||
`"dense_1/kernel:0"`. It is recommended that you use the layer properties to
|
||||
access specific variables, e.g. `model.get_layer("dense_1").kernel`.
|
||||
|
||||
_SavedModel serialization_
|
||||
__SavedModel serialization format__
|
||||
|
||||
The SavedModel serialization path uses `tf.saved_model.save` to save the model
|
||||
and all trackable objects attached to the model (e.g. layers and variables).
|
||||
`@tf.function`-decorated methods are also saved. Additional trackable objects
|
||||
and functions are added to the SavedModel to allow the model to be
|
||||
loaded back as a Keras Model object.
|
||||
Keras SavedModel uses `tf.saved_model.save` to save the model and all
|
||||
trackable objects attached to the model (e.g. layers and variables). The model
|
||||
config, weights, and optimizer are saved in the SavedModel. Additionally, for
|
||||
every Keras layer attached to the model, the SavedModel stores:
|
||||
|
||||
* the config and metadata -- e.g. name, dtype, trainable status
|
||||
* traced call and loss functions, which are stored as TensorFlow subgraphs.
|
||||
|
||||
The traced functions allow the SavedModel format to save and load custom
|
||||
layers without the original class definition.
|
||||
|
||||
You can choose to not save the traced functions by disabling the `save_traces`
|
||||
option. This will decrease the time it takes to save the model and the
|
||||
amount of disk space occupied by the output SavedModel. If you enable this
|
||||
option, then you _must_ provide all custom class definitions when loading
|
||||
the model. See the `custom_objects` argument in `tf.keras.models.load_model`.
|
||||
|
||||
Arguments:
|
||||
model: Keras model instance to be saved.
|
||||
@ -102,12 +117,19 @@ def save_model(model,
|
||||
signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
|
||||
format only. Please see the `signatures` argument in
|
||||
`tf.saved_model.save` for details.
|
||||
options: Optional `tf.saved_model.SaveOptions` object that specifies
|
||||
options for saving to SavedModel.
|
||||
options: (only applies to SavedModel format) `tf.saved_model.SaveOptions`
|
||||
object that specifies options for saving to SavedModel.
|
||||
save_traces: (only applies to SavedModel format) When enabled, the
|
||||
SavedModel will store the function traces for each layer. This
|
||||
can be disabled, so that only the configs of each layer are stored.
|
||||
Defaults to `True`. Disabling this will decrease serialization time and
|
||||
reduce file size, but it requires that all custom layers/models
|
||||
implement a `get_config()` method.
|
||||
|
||||
Raises:
|
||||
ImportError: If save format is hdf5, and h5py is not available.
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
from tensorflow.python.keras.engine import sequential # pylint: disable=g-import-not-at-top
|
||||
|
||||
default_format = 'tf' if tf2.enabled() else 'h5'
|
||||
@ -132,7 +154,7 @@ def save_model(model,
|
||||
model, filepath, overwrite, include_optimizer)
|
||||
else:
|
||||
saved_model_save.save(model, filepath, overwrite, include_optimizer,
|
||||
signatures, options)
|
||||
signatures, options, save_traces)
|
||||
|
||||
|
||||
@keras_export('keras.models.load_model')
|
||||
|
@ -22,6 +22,7 @@ import abc
|
||||
import six
|
||||
|
||||
from tensorflow.python.keras.saving.saved_model import json_utils
|
||||
from tensorflow.python.keras.saving.saved_model import utils
|
||||
from tensorflow.python.training.tracking import tracking
|
||||
|
||||
|
||||
@ -71,6 +72,9 @@ class SavedModelSaver(object):
|
||||
A dictionary mapping attribute names to trackable objects. The entire list
|
||||
of attributes are listed in the `saved_model._LayerAttributes` class.
|
||||
"""
|
||||
if not utils.should_save_traces():
|
||||
return {}
|
||||
|
||||
return self.objects_to_serialize(serialization_cache)
|
||||
|
||||
def list_functions_for_serialization(self, serialization_cache):
|
||||
@ -84,6 +88,9 @@ class SavedModelSaver(object):
|
||||
A dictionary mapping attribute names to `Function` or
|
||||
`ConcreteFunction`.
|
||||
"""
|
||||
if not utils.should_save_traces():
|
||||
return {}
|
||||
|
||||
fns = self.functions_to_serialize(serialization_cache)
|
||||
|
||||
# The parent AutoTrackable class saves all user-defined tf.functions, and
|
||||
|
@ -278,6 +278,16 @@ class KerasObjectLoader(object):
|
||||
for name in PUBLIC_ATTRIBUTES:
|
||||
delete_tracking(node, name)
|
||||
|
||||
if isinstance(node, functional_lib.Functional):
|
||||
# Delete the temporary layer dependencies, which were used to restore
|
||||
# the checkpointed values. When the model is live, the user can delete
|
||||
# or add layers to the model at any time, so these layer dependencies
|
||||
# may be obsolete.
|
||||
dependencies = list(node._self_unconditional_dependency_names) # pylint: disable=protected-access
|
||||
for name in dependencies:
|
||||
if re.match(r'^layer(_with_weights)?-[\d+]', name) is not None:
|
||||
delete_tracking(node, name)
|
||||
|
||||
def _add_children_recreated_from_config(self, obj, proto, node_id):
|
||||
"""Recursively records objects recreated from config."""
|
||||
# pylint: disable=protected-access
|
||||
@ -302,7 +312,7 @@ class KerasObjectLoader(object):
|
||||
# This is stored in the SavedModel as layer.keras_api.layer_metrics in
|
||||
# SavedModels created after Tf 2.2.
|
||||
metric_list_node_id = self._search_for_child_node(
|
||||
node_id, [constants.KERAS_ATTR, 'layer_metrics'], raise_error=False)
|
||||
node_id, [constants.KERAS_ATTR, 'layer_metrics'])
|
||||
if metric_list_node_id is not None and hasattr(obj, '_metrics'):
|
||||
obj_metrics = {m.name: m for m in obj._metrics}
|
||||
for reference in self._proto.nodes[metric_list_node_id].children:
|
||||
@ -384,8 +394,10 @@ class KerasObjectLoader(object):
|
||||
|
||||
config = metadata.get('config')
|
||||
if _is_graph_network(node) and generic_utils.validate_config(config):
|
||||
self.model_layer_dependencies[node_id] = (
|
||||
node, self._get_child_layer_node_ids(node_id, node.name))
|
||||
child_nodes = self._get_child_layer_node_ids(node_id)
|
||||
self.model_layer_dependencies[node_id] = (node, child_nodes)
|
||||
if not child_nodes:
|
||||
self._models_to_reconstruct.append(node_id)
|
||||
return node, setter
|
||||
|
||||
# Detect whether this object can be revived from the config. If not, then
|
||||
@ -448,9 +460,10 @@ class KerasObjectLoader(object):
|
||||
|
||||
# Record this model and its layers. This will later be used to reconstruct
|
||||
# the model.
|
||||
layers = self._get_child_layer_node_ids(node_id, model.name)
|
||||
layers = self._get_child_layer_node_ids(node_id)
|
||||
self.model_layer_dependencies[node_id] = (model, layers)
|
||||
|
||||
if not layers:
|
||||
self._models_to_reconstruct.append(node_id)
|
||||
return model
|
||||
|
||||
def _revive_layer_from_config(self, metadata, node_id):
|
||||
@ -621,8 +634,14 @@ class KerasObjectLoader(object):
|
||||
"""Reconstructs the network structure."""
|
||||
config = json_utils.decode(
|
||||
self._proto.nodes[model_id].user_object.metadata)['config']
|
||||
if isinstance(model, models_lib.Sequential):
|
||||
if not isinstance(layers[0], input_layer.InputLayer):
|
||||
|
||||
# Set up model inputs
|
||||
if model.inputs:
|
||||
# Inputs may already be created if the model is instantiated in another
|
||||
# object's __init__.
|
||||
pass
|
||||
elif isinstance(model, models_lib.Sequential):
|
||||
if not layers or not isinstance(layers[0], input_layer.InputLayer):
|
||||
if config['layers'][0]['class_name'] == 'InputLayer':
|
||||
layers.insert(0, input_layer.InputLayer.from_config(
|
||||
config['layers'][0]['config']))
|
||||
@ -635,13 +654,13 @@ class KerasObjectLoader(object):
|
||||
name=layers[0].name + '_input'))
|
||||
model.__init__(layers, name=config['name'])
|
||||
if not model.inputs:
|
||||
first_layer = self._get_child_layer_node_ids(model_id, model.name)[0]
|
||||
first_layer = self._get_child_layer_node_ids(model_id)[0]
|
||||
input_specs = self._infer_inputs(first_layer)
|
||||
input_shapes = self._infer_inputs(first_layer, convert_to_shapes=True)
|
||||
model._set_inputs(input_specs) # pylint: disable=protected-access
|
||||
if not model.built and not isinstance(input_specs, dict):
|
||||
model.build(input_shapes)
|
||||
else:
|
||||
else: # Reconstruct functional model
|
||||
(inputs, outputs,
|
||||
created_layers) = functional_lib.reconstruct_from_config(
|
||||
config, created_layers={layer.name: layer for layer in layers})
|
||||
@ -654,15 +673,31 @@ class KerasObjectLoader(object):
|
||||
# Unblock models that are dependent on this model.
|
||||
self._unblock_model_reconstruction(model_id, model)
|
||||
|
||||
def _get_child_layer_node_ids(self, node_id, name):
|
||||
"""Returns the node ids of the children layers of a node."""
|
||||
# Retrieve the node id of layer.keras_api.layers.
|
||||
layer_list = self._search_for_child_node(
|
||||
node_id, [constants.KERAS_ATTR, 'layers'], name)
|
||||
return [node.node_id for node in self._proto.nodes[layer_list].children]
|
||||
def _get_child_layer_node_ids(self, node_id):
|
||||
"""Returns the node ids of each layer in a Sequential/Functional model."""
|
||||
# Sequential and Functional track layers with names following the format
|
||||
# "layer-N". Use this to generate the list of layers.
|
||||
num_layers = 0
|
||||
child_layers = {}
|
||||
pattern = re.compile('layer-(\\d+)')
|
||||
|
||||
def _search_for_child_node(
|
||||
self, parent_id, path_to_child, debugging_name=None, raise_error=True):
|
||||
for child in self._proto.nodes[node_id].children:
|
||||
m = pattern.match(child.local_name)
|
||||
if m is None:
|
||||
continue
|
||||
layer_n = int(m.group(1))
|
||||
num_layers = max(layer_n + 1, num_layers)
|
||||
child_layers[layer_n] = child.node_id
|
||||
|
||||
ordered = []
|
||||
for n in range(num_layers):
|
||||
child = child_layers.get(n)
|
||||
if child is None:
|
||||
break
|
||||
ordered.append(child)
|
||||
return ordered
|
||||
|
||||
def _search_for_child_node(self, parent_id, path_to_child):
|
||||
"""Returns node id of child node.
|
||||
|
||||
A helper method for traversing the object graph proto.
|
||||
@ -680,37 +715,23 @@ class KerasObjectLoader(object):
|
||||
Args:
|
||||
parent_id: node id of parent node
|
||||
path_to_child: list of children names.
|
||||
debugging_name: the name to print out when raising an error.
|
||||
raise_error: Whether to raise an error if the child isn't found.
|
||||
|
||||
Returns:
|
||||
node_id of child, or None if child isn't found.
|
||||
|
||||
Raises:
|
||||
ValueError: if child isn't found and raise_error is True.
|
||||
"""
|
||||
if not path_to_child:
|
||||
return parent_id
|
||||
|
||||
for child in self._proto.nodes[parent_id].children:
|
||||
if child.local_name == path_to_child[0]:
|
||||
return self._search_for_child_node(child.node_id, path_to_child[1:],
|
||||
debugging_name, raise_error)
|
||||
|
||||
if raise_error:
|
||||
raise ValueError(
|
||||
'Error when loading {}: could not find attribute {}.\n'
|
||||
'Most likely this object was serialized incorrectly.'
|
||||
.format(debugging_name or path_to_child[0], path_to_child[0]))
|
||||
else:
|
||||
return None
|
||||
return self._search_for_child_node(child.node_id, path_to_child[1:])
|
||||
return None
|
||||
|
||||
def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
|
||||
"""Infers input shape of layer from SavedModel functions."""
|
||||
coder = nested_structure_coder.StructureCoder()
|
||||
call_fn_id = self._search_for_child_node(
|
||||
layer_node_id, ['call_and_return_all_conditional_losses'], None,
|
||||
raise_error=False)
|
||||
layer_node_id, ['call_and_return_all_conditional_losses'])
|
||||
if call_fn_id is None:
|
||||
return None
|
||||
|
||||
@ -797,9 +818,9 @@ def _unable_to_call_layer_due_to_serialization_issue(
|
||||
"""
|
||||
|
||||
raise ValueError(
|
||||
'Cannot call {} ({}), because the call function was not serialized to '
|
||||
'the SavedModel (due to lack information about the inputs). Please try '
|
||||
'one of the following methods to fix the serialization:'
|
||||
'Cannot call custom layer {} of type {}, because the call function was '
|
||||
'not serialized to the SavedModel.'
|
||||
'Please try one of the following methods to fix this issue:'
|
||||
'\n\n(1) Implement `get_config` and `from_config` in the layer/model '
|
||||
'class, and pass the object to the `custom_objects` argument when '
|
||||
'loading the model. For more details, see: '
|
||||
@ -808,7 +829,7 @@ def _unable_to_call_layer_due_to_serialization_issue(
|
||||
'and not `__call__`. The input shape and dtype will be automatically '
|
||||
'recorded when the object is called, and used when saving. To manually '
|
||||
'specify the input shape/dtype, decorate the call function with '
|
||||
'`@tf.function(input_signature=...)`.'.format(layer.name, layer))
|
||||
'`@tf.function(input_signature=...)`.'.format(layer.name, type(layer)))
|
||||
|
||||
|
||||
def _finalize_config_layers(layers):
|
||||
@ -976,8 +997,11 @@ def _revive_setter(layer, name, value):
|
||||
elif (isinstance(layer, functional_lib.Functional) and
|
||||
re.match(r'^layer(_with_weights)?-[\d+]', name) is not None):
|
||||
# Edges named "layer-n" or "layer_with_weights-n", which are tracked in
|
||||
# network._track_layers, should not be added as an attribute.
|
||||
pass
|
||||
# network._track_layers, should not be added as an attribute. They should
|
||||
# be temporarily added as a dependency so that checkpointed values can be
|
||||
# restored. These dependencies are manually deleted in
|
||||
# KerasObjectLoader.del_tracking.
|
||||
layer._track_trackable(value, name) # pylint: disable=protected-access
|
||||
elif getattr(layer, name, None) is not None:
|
||||
# Don't overwrite already defined attributes.
|
||||
pass
|
||||
|
@ -22,6 +22,7 @@ from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.saving.saved_model import save_impl
|
||||
from tensorflow.python.keras.saving.saved_model import utils
|
||||
from tensorflow.python.keras.utils.generic_utils import LazyLoader
|
||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from tensorflow.python.saved_model import save as save_lib
|
||||
@ -38,7 +39,7 @@ training_lib = LazyLoader(
|
||||
|
||||
|
||||
def save(model, filepath, overwrite, include_optimizer, signatures=None,
|
||||
options=None):
|
||||
options=None, save_traces=True):
|
||||
"""Saves a model as a SavedModel to the filepath.
|
||||
|
||||
Args:
|
||||
@ -49,8 +50,14 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
|
||||
signatures: Signatures to save with the SavedModel. Applicable to the 'tf'
|
||||
format only. Please see the `signatures` argument in `tf.saved_model.save`
|
||||
for details.
|
||||
options: Optional `tf.saved_model.SaveOptions` object that specifies
|
||||
options for saving to SavedModel.
|
||||
options: (only applies to SavedModel format) `tf.saved_model.SaveOptions`
|
||||
object that specifies options for saving to SavedModel.
|
||||
save_traces: (only applies to SavedModel format) When enabled, the
|
||||
SavedModel will store the function traces for each layer. This
|
||||
can be disabled, so that only the configs of each layer are stored.
|
||||
Defaults to `True`. Disabling this will decrease serialization time
|
||||
and reduce file size, but it requires that all custom layers/models
|
||||
implement a `get_config()` method.
|
||||
|
||||
Raises:
|
||||
ValueError: if the model's inputs have not been defined.
|
||||
@ -61,8 +68,9 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
|
||||
if not proceed:
|
||||
return
|
||||
|
||||
if save_impl.should_skip_serialization(model):
|
||||
saving_utils.raise_model_input_error(model)
|
||||
if save_traces:
|
||||
if save_impl.should_skip_serialization(model):
|
||||
saving_utils.raise_model_input_error(model)
|
||||
|
||||
if not include_optimizer:
|
||||
orig_optimizer = model.optimizer
|
||||
@ -77,7 +85,8 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
|
||||
# the replica context is not available when calling `add_update()`, and thus
|
||||
# we use the default replica context here.
|
||||
with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access
|
||||
save_lib.save(model, filepath, signatures, options)
|
||||
with utils.keras_option_scope(save_traces):
|
||||
save_lib.save(model, filepath, signatures, options)
|
||||
|
||||
if not include_optimizer:
|
||||
model.optimizer = orig_optimizer
|
||||
|
@ -114,7 +114,7 @@ class GlobalLayerThatShouldFailIfNotAdded(keras.layers.Layer):
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
|
||||
def _save_model_dir(self, dirname='saved_model'):
|
||||
temp_dir = self.get_temp_dir()
|
||||
@ -829,6 +829,14 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
self.assertAllClose(model.predict(f), loaded.predict(f))
|
||||
|
||||
|
||||
class TestSavedModelFormat(test.TestCase):
|
||||
|
||||
def _save_model_dir(self, dirname='saved_model'):
|
||||
temp_dir = self.get_temp_dir()
|
||||
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
|
||||
return os.path.join(temp_dir, dirname)
|
||||
|
||||
def test_load_with_partially_failed_serialization(self):
|
||||
|
||||
class BadCustomLayer(keras.layers.Layer):
|
||||
@ -858,6 +866,48 @@ class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
with self.assertRaisesRegex(ValueError, 'call function was not serialized'):
|
||||
loaded.layer(inp)
|
||||
|
||||
def test_save_without_tracing(self):
|
||||
|
||||
class DoNotTrace(keras.layers.Layer):
|
||||
|
||||
def __init__(self):
|
||||
super(DoNotTrace, self).__init__()
|
||||
self.input_spec = keras.layers.InputSpec(shape=[None])
|
||||
self.built = True
|
||||
|
||||
def call(self, inputs):
|
||||
raise ValueError('I said do not trace')
|
||||
|
||||
def get_config(self):
|
||||
return {}
|
||||
|
||||
root = keras.models.Sequential()
|
||||
root.add(keras.layers.Input(shape=(3,)))
|
||||
root.attached_layer = DoNotTrace()
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
|
||||
# With the default settings, the call function is traced.
|
||||
with self.assertRaisesRegex(ValueError, 'do not trace'):
|
||||
root.save(saved_model_dir, save_format='tf')
|
||||
|
||||
# When saving the config only, the layer call function should not be not
|
||||
# traced.
|
||||
root.save(saved_model_dir, save_format='tf', save_traces=False)
|
||||
loaded = tf_load.load(saved_model_dir)
|
||||
self.assertTrue(hasattr(loaded, 'attached_layer'))
|
||||
|
||||
# This should raise an error when loaded without the custom object
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
with self.assertRaisesRegex(ValueError, 'Cannot call custom layer'):
|
||||
loaded.attached_layer(constant_op.constant([1.]))
|
||||
|
||||
# Try loading with the custom objects
|
||||
with generic_utils.CustomObjectScope({'DoNotTrace': DoNotTrace}):
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
with self.assertRaisesRegex(ValueError, 'I said do not trace'):
|
||||
loaded.attached_layer(constant_op.constant([1.]))
|
||||
|
||||
|
||||
class TestLayerCallTracing(test.TestCase, parameterized.TestCase):
|
||||
|
||||
|
@ -18,6 +18,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import itertools
|
||||
import threading
|
||||
import types
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
@ -25,6 +26,7 @@ from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.utils import control_flow_util
|
||||
from tensorflow.python.keras.utils import layer_utils
|
||||
from tensorflow.python.keras.utils import tf_contextlib
|
||||
from tensorflow.python.keras.utils import tf_inspect
|
||||
from tensorflow.python.keras.utils.generic_utils import LazyLoader
|
||||
from tensorflow.python.util import tf_decorator
|
||||
@ -245,3 +247,27 @@ def remove_training_arg(index, args, kwargs):
|
||||
args.pop(index)
|
||||
else:
|
||||
kwargs.pop('training', None)
|
||||
|
||||
|
||||
class SaveOptionsContext(threading.local):
|
||||
|
||||
def __init__(self):
|
||||
super(SaveOptionsContext, self).__init__()
|
||||
self.save_traces = True
|
||||
|
||||
|
||||
_save_options_context = SaveOptionsContext()
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def keras_option_scope(save_traces):
|
||||
previous_value = _save_options_context.save_traces
|
||||
try:
|
||||
_save_options_context.save_traces = save_traces
|
||||
yield
|
||||
finally:
|
||||
_save_options_context.save_traces = previous_value
|
||||
|
||||
|
||||
def should_save_traces():
|
||||
return _save_options_context.save_traces
|
||||
|
@ -304,6 +304,7 @@ _thread_local_data = threading.local()
|
||||
_thread_local_data.model_type = None
|
||||
_thread_local_data.run_eagerly = None
|
||||
_thread_local_data.saved_model_format = None
|
||||
_thread_local_data.save_kwargs = None
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
@ -383,7 +384,7 @@ def should_run_eagerly():
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def saved_model_format_scope(value):
|
||||
def saved_model_format_scope(value, **kwargs):
|
||||
"""Provides a scope within which the savde model format to test is `value`.
|
||||
|
||||
The saved model format gets restored to its original value upon exiting the
|
||||
@ -391,17 +392,21 @@ def saved_model_format_scope(value):
|
||||
|
||||
Arguments:
|
||||
value: saved model format value
|
||||
**kwargs: optional kwargs to pass to the save function.
|
||||
|
||||
Yields:
|
||||
The provided value.
|
||||
"""
|
||||
previous_value = _thread_local_data.saved_model_format
|
||||
previous_format = _thread_local_data.saved_model_format
|
||||
previous_kwargs = _thread_local_data.save_kwargs
|
||||
try:
|
||||
_thread_local_data.saved_model_format = value
|
||||
yield value
|
||||
_thread_local_data.save_kwargs = kwargs
|
||||
yield
|
||||
finally:
|
||||
# Restore saved model format to initial value.
|
||||
_thread_local_data.saved_model_format = previous_value
|
||||
_thread_local_data.saved_model_format = previous_format
|
||||
_thread_local_data.save_kwargs = previous_kwargs
|
||||
|
||||
|
||||
def get_save_format():
|
||||
@ -413,6 +418,15 @@ def get_save_format():
|
||||
return _thread_local_data.saved_model_format
|
||||
|
||||
|
||||
def get_save_kwargs():
|
||||
if _thread_local_data.save_kwargs is None:
|
||||
raise ValueError(
|
||||
'Cannot call `get_save_kwargs()` outside of a '
|
||||
'`saved_model_format_scope()` or `run_with_all_saved_model_formats` '
|
||||
'decorator.')
|
||||
return _thread_local_data.save_kwargs or {}
|
||||
|
||||
|
||||
def get_model_type():
|
||||
"""Gets the model type that should be tested."""
|
||||
if _thread_local_data.model_type is None:
|
||||
|
@ -310,7 +310,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
|
@ -328,7 +328,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
|
@ -311,7 +311,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
|
@ -311,7 +311,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
|
@ -310,7 +310,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
|
@ -328,7 +328,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
|
@ -30,6 +30,6 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "save_model"
|
||||
argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
}
|
||||
|
@ -310,7 +310,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
|
@ -328,7 +328,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
|
@ -311,7 +311,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
|
@ -311,7 +311,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
|
@ -310,7 +310,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
|
@ -328,7 +328,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "save"
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "save_weights"
|
||||
|
@ -30,6 +30,6 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "save_model"
|
||||
argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'model\', \'filepath\', \'overwrite\', \'include_optimizer\', \'save_format\', \'signatures\', \'options\', \'save_traces\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'None\', \'None\', \'None\', \'True\'], "
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user