Code for saving and reviving Keras models and layers. Implements go/tf-model-serialization.

To save and revive a model:
1. Save the model using tf.saved_model.save.
2. Call load_from_saved_model_v2.

This restores various metadata about Keras models and layers, as well as their call and loss functions.

Changes to object serialization:
- Adds private fields for tracking an object's identifier and metadata.
- Adds _list_extra_dependencies_for_serialization, which allows objects to save extra
  dependencies when serialized to a SavedModel.
- The object graph view maintains a serialization cache object that is passed to each object when serializing functions/extra dependencies.
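
As a usage sketch (the toy model, shapes, and path below are illustrative
assumptions, not part of this change):

  import tensorflow as tf
  from tensorflow.python.keras.saving import saved_model as keras_saved_model

  model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
  tf.saved_model.save(model, '/tmp/keras_saved_model')
  revived = keras_saved_model.load_from_saved_model_v2('/tmp/keras_saved_model')
  outputs = revived(tf.ones([1, 3]))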

PiperOrigin-RevId: 251386039
Katherine Wu 2019-06-04 00:28:15 -07:00 committed by TensorFlower Gardener
parent ece5314ddb
commit eff4ae822a
27 changed files with 1751 additions and 206 deletions

View File

@@ -74,6 +74,8 @@ message SavedUserObject {
string identifier = 1;
// Version information from the producer of this SavedUserObject.
VersionDef version = 2;
// Initialization-related metadata.
string metadata = 3;
}
// A SavedAsset points to an asset in the MetaGraph.

View File

@@ -160,7 +160,6 @@ py_library(
"engine/__init__.py",
"engine/base_layer.py",
"engine/input_layer.py",
"engine/input_spec.py",
"engine/network.py",
"engine/partial_batch_padding_handler.py",
"engine/saving.py",
@@ -185,6 +184,7 @@ py_library(
":constraints",
":engine_utils",
":initializers",
":input_spec",
":losses",
":mode_keys",
":optimizers",
@@ -210,6 +210,20 @@ py_library(
],
)
py_library(
name = "input_spec",
srcs = ["engine/input_spec.py"],
srcs_version = "PY2AND3",
deps = [
":backend",
"//tensorflow/python:dtypes",
"//tensorflow/python:lib",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"@six_archive//:six",
],
)
py_library(
name = "saving",
srcs = [
@@ -224,12 +238,18 @@ py_library(
deps = [
":backend",
":engine_utils",
":input_spec",
":mode_keys",
":optimizers",
":regularizers",
"//tensorflow/python:lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:saver",
"//tensorflow/python:tensor_spec",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/saved_model",
"//tensorflow/python/saved_model/model_utils",
"//tensorflow/python/training/tracking",
],
)
@@ -1239,6 +1259,17 @@ tf_py_test(
tags = ["notsan"],
)
tf_py_test(
name = "input_spec_test",
size = "small",
srcs = ["engine/input_spec_test.py"],
additional_deps = [
":keras",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:client_testlib",
],
)
tf_py_test(
name = "training_test",
size = "medium",
@@ -1511,6 +1542,7 @@ tf_py_test(
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
],
shard_count = 4,
tags = [
"no_oss", # TODO(b/119349471): Re-enable
"no_windows",

View File

@@ -22,6 +22,7 @@ import collections
import functools
import inspect # Necessary supplement to tf_inspect to deal with variadic args.
import itertools
import json
import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
@@ -46,9 +47,11 @@ from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.keras.saving import saved_model
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
# Modules that only depend on `keras.layers` import these from here.
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.keras.utils.generic_utils import to_snake_case # pylint: disable=unused-import
from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list # pylint: disable=unused-import
from tensorflow.python.module import module
@@ -64,6 +67,7 @@ from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export
@@ -2131,6 +2135,109 @@ class Layer(module.Module):
return ('mask' in self._call_fn_args or
getattr(self, 'compute_mask', None) is not None)
@property
def _object_identifier(self):
"""String stored in object identifier field in the SavedModel proto.
Returns:
A string with the object identifier, which is used at load time.
"""
return '_tf_keras_layer'
@property
def _tracking_metadata(self):
"""String stored in metadata field in the SavedModel proto.
Returns:
A serialized JSON string storing information necessary for recreating this layer.
"""
# TODO(kathywu): Add support for metrics serialization.
# TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
# the python config serialization has caught up.
# Create a dictionary containing python layer state attributes. Any new
# attribute that impacts the layer execution in some way should be added to
# this dict.
# Unlike a model's JSON configuration, which only
# contains class_name and each layer's get_config() object, this stores more
# information to accurately recreate the layer.
# For backwards compatibility, any changes to this list should be additive.
# Modifying or removing attributes may only be done with a sufficient
# explanation.
metadata = dict(
class_name=type(self).__name__,
name=self.name,
trainable=self.trainable,
expects_training_arg=self._expects_training_arg,
dtype=self.dtype,
batch_input_shape=getattr(self, '_batch_input_shape', None))
try:
# Store the config dictionary, which is only used by the revived object
# to return the original config when revived_obj.get_config() is called.
# It is not important for recreating the revived object.
metadata['config'] = self.get_config()
except NotImplementedError:
# In the case of a subclassed model, the get_config() method raises
# NotImplementedError.
pass
if self.input_spec is not None:
metadata['input_spec'] = nest.map_structure(
lambda x: x.get_config(), self.input_spec)
else:
metadata['input_spec'] = None
if (self.activity_regularizer is not None and
hasattr(self.activity_regularizer, 'get_config')):
metadata['activity_regularizer'] = serialize_keras_object(
self.activity_regularizer)
else:
metadata['activity_regularizer'] = None
return json.dumps(metadata, default=serialization.get_json_type)
def _list_extra_dependencies_for_serialization(self, serialization_cache):
"""Lists extra dependencies to serialize to SavedModel.
By overriding this method, extra dependencies can be attached to the
serialized Layer. For example, this is used to save the list of `variables`
and `trainable_variables`, which are python properties in a Layer object,
but are represented as a static list in the SavedModel.
Args:
serialization_cache: A dictionary shared between all objects in the same
object graph. This object is passed to both
`_list_extra_dependencies_for_serialization` and
`_list_functions_for_serialization`.
Returns:
A dictionary mapping attribute names to trackable objects. The entire list
of attributes is given in the `saved_model.LayerAttributes` class.
"""
return (saved_model.serialize_all_attributes(self, serialization_cache)
.objects_to_serialize)
def _list_functions_for_serialization(self, serialization_cache):
"""Lists the functions to include when serializing a Layer.
Args:
serialization_cache: Dictionary passed to all objects in the same object
graph during serialization.
Returns:
A dictionary mapping attribute names to `Function` or
`ConcreteFunction`. The entire list of attributes is given in the
`saved_model.LayerAttributes` class.
"""
# Create a dictionary containing the layer's call and loss functions.
fns = (saved_model.serialize_all_attributes(self, serialization_cache)
.functions_to_serialize)
# The parent AutoTrackable class saves all user-defined tf.functions, and
# returns them in _list_functions_for_serialization(). Add these functions
# to the dict.
fns.update(super(Layer, self)._list_functions_for_serialization(
serialization_cache))
return fns
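# As an illustrative sketch (values assumed, not taken from this diff): for a
# built Dense layer, json.loads(layer._tracking_metadata) decodes to a dict
# along the lines of:
#   {'class_name': 'Dense', 'name': 'dense', 'trainable': True,
#    'expects_training_arg': False, 'dtype': 'float32',
#    'batch_input_shape': None, 'config': {...}, 'input_spec': {...},
#    'activity_regularizer': None}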
class Node(object):
"""A `Node` describes the connectivity between two layers.

View File

@@ -289,6 +289,9 @@ def is_in_eager_or_tf_function():
def is_in_tf_function():
"""Returns if inside of a tf.function."""
# Check if running in V1 graph mode.
if not ops.executing_eagerly_outside_functions():
return False
if not ops.inside_function():
return False
# Check if inside Keras FuncGraph.

View File

@@ -20,12 +20,16 @@ from __future__ import print_function
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export
@keras_export('keras.layers.InputSpec', v1=['keras.layers.InputSpec'])
@keras_export('keras.layers.InputSpec')
@tf_export(v1=['layers.InputSpec'])
class InputSpec(object):
"""Specifies the ndim, dtype and shape of every input to a layer.
@@ -54,15 +58,27 @@ class InputSpec(object):
max_ndim=None,
min_ndim=None,
axes=None):
self.dtype = dtype
self.shape = shape
self.dtype = dtypes.as_dtype(dtype).name if dtype is not None else None
if shape is not None:
self.ndim = len(shape)
self.shape = shape
else:
self.ndim = ndim
self.shape = None
self.max_ndim = max_ndim
self.min_ndim = min_ndim
self.axes = axes or {}
try:
axes = axes or {}
self.axes = {int(k): axes[k] for k in axes}
except (ValueError, TypeError):
raise TypeError('The keys in axes must be integers.')
if self.axes and (self.ndim is not None or self.max_ndim is not None):
max_dim = (self.ndim if self.ndim else self.max_ndim) - 1
max_axis = max(self.axes)
if max_axis > max_dim:
raise ValueError('Axis {} is greater than the maximum allowed value: {}'
.format(max_axis, max_dim))
def __repr__(self):
spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
@@ -73,6 +89,42 @@ class InputSpec(object):
('axes=' + str(self.axes)) if self.axes else '']
return 'InputSpec(%s)' % ', '.join(x for x in spec if x)
def get_config(self):
return {
'dtype': self.dtype,
'shape': self.shape,
'ndim': self.ndim,
'max_ndim': self.max_ndim,
'min_ndim': self.min_ndim,
'axes': self.axes}
@classmethod
def from_config(cls, config):
return cls(**config)
def to_tensor_shape(spec):
"""Returns a tf.TensorShape object that matches the shape specifications.
If the InputSpec's shape or ndim is defined, this method will return a fully
or partially-known shape. Otherwise, the returned TensorShape is unknown.
Args:
spec: an InputSpec object.
Returns:
a tf.TensorShape object
"""
if spec.ndim is None and spec.shape is None:
return tensor_shape.TensorShape(None)
elif spec.shape is not None:
return tensor_shape.TensorShape(spec.shape)
else:
shape = [None] * spec.ndim
for a in spec.axes:
shape[a] = spec.axes[a] # Assume that axes is defined
return tensor_shape.TensorShape(shape)
def assert_input_compatibility(input_spec, inputs, layer_name):
"""Checks compatibility between the layer and provided inputs.
@@ -168,3 +220,12 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
' is incompatible with layer ' + layer_name +
': expected shape=' + str(spec.shape) +
', found shape=' + str(shape))
def to_tensor_spec(input_spec, default_dtype=None):
"""Converts a Keras InputSpec object to a TensorSpec."""
default_dtype = default_dtype or backend.floatx()
if isinstance(input_spec, InputSpec):
dtype = input_spec.dtype or default_dtype
return tensor_spec.TensorSpec(to_tensor_shape(input_spec), dtype)
return tensor_spec.TensorSpec(None, default_dtype)
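# Illustrative sketch of the conversion helpers above (values assumed):
#   spec = InputSpec(ndim=3, axes={-1: 8})
#   to_tensor_shape(spec).as_list()  # [None, None, 8]
#   to_tensor_spec(spec)             # TensorSpec with shape (None, None, 8)
#                                    # and dtype backend.floatx(), e.g. float32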

View File

@@ -0,0 +1,66 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""InputSpec tests."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.platform import test
class InputSpecTest(test.TestCase):
def test_axes_initialization(self):
input_spec.InputSpec(shape=[1, None, 2, 3], axes={3: 5, '2': 2})
with self.assertRaisesRegexp(ValueError, 'Axis 4 is greater than'):
input_spec.InputSpec(shape=[1, None, 2, 3], axes={4: 5})
with self.assertRaisesRegexp(TypeError, 'keys in axes must be integers'):
input_spec.InputSpec(shape=[1, None, 2, 3], axes={'string': 5})
class InputSpecToTensorShapeTest(test.TestCase):
def test_defined_shape(self):
spec = input_spec.InputSpec(shape=[1, None, 2, 3])
self.assertAllEqual(
[1, None, 2, 3], input_spec.to_tensor_shape(spec).as_list())
def test_defined_ndims(self):
spec = input_spec.InputSpec(ndim=5)
self.assertAllEqual(
[None] * 5, input_spec.to_tensor_shape(spec).as_list())
spec = input_spec.InputSpec(ndim=0)
self.assertAllEqual(
[], input_spec.to_tensor_shape(spec).as_list())
spec = input_spec.InputSpec(ndim=3, axes={1: 3, -1: 2})
self.assertAllEqual(
[None, 3, 2], input_spec.to_tensor_shape(spec).as_list())
def test_undefined_shapes(self):
spec = input_spec.InputSpec(max_ndim=5)
with self.assertRaisesRegexp(ValueError, 'unknown TensorShape'):
input_spec.to_tensor_shape(spec).as_list()
spec = input_spec.InputSpec(min_ndim=5, max_ndim=5)
with self.assertRaisesRegexp(ValueError, 'unknown TensorShape'):
input_spec.to_tensor_shape(spec).as_list()
if __name__ == '__main__':
test.main()

View File

@@ -1663,6 +1663,10 @@ class Network(base_layer.Layer):
'inputs or `build()` is called with an `input_shape`.' %
self.name)
@property
def _object_identifier(self):
return '_tf_keras_network'
def _is_hdf5_filepath(filepath):
return (filepath.endswith('.h5') or filepath.endswith('.keras') or

View File

@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections
import json
import numpy as np
from tensorflow.python import tf2
@@ -55,6 +56,7 @@ from tensorflow.python.ops.losses import util as tf_losses_utils
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
from tensorflow.python.util.tf_export import keras_export
@@ -1707,26 +1709,6 @@ class Model(network.Network):
batch_size = 32
return batch_size
def _list_functions_for_serialization(self):
"""If available, saves a trace of call using self.inputs."""
all_functions = super(Model, self)._list_functions_for_serialization()
try:
# pylint:disable=pointless-statement
self.inputs
self.input_names
# pylint:enable=pointless-statement
except AttributeError:
# If the model does not have inputs set, because it was not called or its
# input shapes were not recorded, we won't have a signature so can't trace
# a function. But the user may still save an object with this Model
# attached; we won't fail the whole tf.saved_model.save.
pass
else:
if '_default_save_signature' not in all_functions:
all_functions['_default_save_signature'] = (
saving_utils.trace_model_call(self))
return all_functions
def _prepare_sample_weights(self, sample_weights=None):
"""Sets sample weight attribute on the model."""
# List with the same length as model outputs.
@@ -2717,6 +2699,17 @@
initial_epoch, mode)
return initial_epoch
@property
def _object_identifier(self):
return '_tf_keras_model'
@property
def _tracking_metadata(self):
metadata = json.loads(super(Model, self)._tracking_metadata)
metadata.update(saving_utils.model_metadata(
self, include_optimizer=True, require_config=False))
return json.dumps(metadata, default=serialization.get_json_type)
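# Illustrative sketch (assumed compiled model): the decoded metadata carries
# the base Layer fields plus the entries produced by
# saving_utils.model_metadata, e.g.:
#   metadata = json.loads(model._tracking_metadata)
#   metadata['backend']          # e.g. 'tensorflow'
#   metadata['training_config']  # optimizer/loss/metrics, when compiled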
def _assert_compile_was_called(self):
# Checks whether `compile` has been called. If it has been called,
# then the optimizer is set. This is different from whether the

View File

@@ -1010,7 +1010,7 @@ def _get_slot_key_from_var(var, slot_name):
return name + "/" + slot_name
class _RestoredOptimizer(OptimizerV2):
class RestoredOptimizer(OptimizerV2):
"""A non-functional Optimizer implementation for checkpoint compatibility.
Holds slot variables and hyperparameters when an optimizer is restored from a
@@ -1022,7 +1022,7 @@ class _RestoredOptimizer(OptimizerV2):
# methods.
def __init__(self):
super(_RestoredOptimizer, self).__init__("_RestoredOptimizer")
super(RestoredOptimizer, self).__init__("RestoredOptimizer")
self._hypers_created = True
def get_config(self):
@@ -1036,9 +1036,9 @@ revived_types.register_revived_type(
"optimizer",
lambda obj: isinstance(obj, OptimizerV2),
versions=[revived_types.VersionedTypeRegistration(
object_factory=lambda proto: _RestoredOptimizer(),
object_factory=lambda proto: RestoredOptimizer(),
version=1,
min_producer_version=1,
min_consumer_version=1,
setter=_RestoredOptimizer._set_hyper # pylint: disable=protected-access
setter=RestoredOptimizer._set_hyper # pylint: disable=protected-access
)])

View File

@@ -28,6 +28,7 @@ from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.saving import model_config as model_config_lib
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
@@ -71,8 +72,6 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
if h5py is None:
raise ImportError('`save_model` requires h5py.')
from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
# TODO(psv) Add warning when we save models that contain non-serializable
# entities like metrics added using `add_metric` and losses added using
# `add_loss.`
@@ -91,48 +90,21 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
opened_new_file = False
try:
f.attrs['keras_version'] = str(keras_version).encode('utf8')
f.attrs['backend'] = K.backend().encode('utf8')
f.attrs['model_config'] = json.dumps(
{
'class_name': model.__class__.__name__,
'config': model.get_config()
},
default=serialization.get_json_type).encode('utf8')
model_metadata = saving_utils.model_metadata(model, include_optimizer)
for k, v in model_metadata.items():
f.attrs[k] = json.dumps(
v, default=serialization.get_json_type).encode('utf8')
model_weights_group = f.create_group('model_weights')
model_layers = model.layers
save_weights_to_hdf5_group(model_weights_group, model_layers)
if include_optimizer and model.optimizer:
if isinstance(model.optimizer, optimizers.TFOptimizer):
logging.warning(
'TensorFlow optimizers do not '
'make it possible to access '
'optimizer attributes or optimizer state '
'after instantiation. '
'As a result, we cannot save the optimizer '
'as part of the model save file. '
'You will have to compile your model again after loading it. '
'Prefer using a Keras optimizer instead '
'(see keras.io/optimizers).')
else:
f.attrs['training_config'] = json.dumps(
{
'optimizer_config': {
'class_name': model.optimizer.__class__.__name__,
'config': model.optimizer.get_config()
},
'loss': model.loss,
'metrics': model._compile_metrics,
'weighted_metrics': model._compile_weighted_metrics,
'sample_weight_mode': model.sample_weight_mode,
'loss_weights': model.loss_weights,
},
default=serialization.get_json_type).encode('utf8')
# TODO(b/128683857): Add integration tests between tf.keras and external
# Keras, to avoid breaking TF.js users.
if (include_optimizer and model.optimizer and
not isinstance(model.optimizer, optimizers.TFOptimizer)):
save_optimizer_weights_to_hdf5_group(f, model.optimizer)
# Save optimizer weights.
save_optimizer_weights_to_hdf5_group(f, model.optimizer)
f.flush()
finally:
if opened_new_file:

View File

@@ -60,6 +60,14 @@ def save_model(model,
the exact same state, without any of the code
used for model definition or training.
_SavedModel serialization_ (not yet added)
The SavedModel serialization path uses `tf.saved_model.save` to save the model
and all trackable objects attached to the model (e.g. layers and variables).
`@tf.function`-decorated methods are also saved. Additional trackable objects
and functions are added to the SavedModel to allow the model to be
loaded back as a Keras Model object.
Arguments:
model: Keras model instance to be saved.
filepath: One of the following:
@@ -147,7 +155,7 @@ def load_model(filepath, custom_objects=None, compile=True):  # pylint: disable=
if isinstance(filepath, six.string_types):
loader_impl.parse_saved_model(filepath)
return saved_model.load_from_saved_model(filepath)
return saved_model.load_from_saved_model_v2(filepath)
raise IOError(
'Unable to load model. Filepath is not an hdf5 file (or h5py is not '

View File

@@ -12,37 +12,77 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utility functions to save/load keras Model to/from SavedModel."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import json
import os
import six
from tensorflow.python.client import session
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as defun
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import optimizers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_spec
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.keras.saving import model_from_json
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import model_utils
from tensorflow.python.saved_model import save as save_lib
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
from tensorflow.python.training.tracking.tracking import AutoTrackable
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import keras_export
# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.
# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
# once the issue with copybara is fixed.
# pylint:disable=g-inconsistent-quotes
metrics_lib = LazyLoader("metrics_lib", globals(),
"tensorflow.python.keras.metrics")
models_lib = LazyLoader("models_lib", globals(),
"tensorflow.python.keras.models")
base_layer = LazyLoader(
"base_layer", globals(),
"tensorflow.python.keras.engine.base_layer")
network_lib = LazyLoader(
"network_lib", globals(),
"tensorflow.python.keras.engine.network")
sequential = LazyLoader(
"sequential", globals(),
"tensorflow.python.keras.engine.sequential")
training_lib = LazyLoader(
"training_lib", globals(),
"tensorflow.python.keras.engine.training")
# pylint:enable=g-inconsistent-quotes
@keras_export('keras.experimental.export_saved_model')
def export_saved_model(model,
@@ -144,9 +184,7 @@ def _export_model_variables(model, saved_model_path):
def _save_v1_format(model, path, custom_objects, as_text, input_signature):
"""Exports model to v1 SavedModel format."""
from tensorflow.python.keras.engine import sequential # pylint: disable=g-import-not-at-top
if not model._is_graph_network:
if not model._is_graph_network: # pylint: disable=protected-access
if isinstance(model, sequential.Sequential):
# If input shape is not directly set in the model, the exported model
# will infer the expected shapes of the input from the model.
@@ -163,7 +201,7 @@ def _save_v1_format(model, path, custom_objects, as_text, input_signature):
'Subclassed models can only be exported for serving. Please set '
'argument serving_only=True.')
builder = saved_model_builder._SavedModelBuilder(path)
builder = saved_model_builder._SavedModelBuilder(path) # pylint: disable=protected-access
# Manually save variables to export them in an object-based checkpoint. This
# skips the `builder.add_meta_graph_and_variables()` step, which saves a
@@ -233,7 +271,6 @@ def _export_mode(
ValueError: If the train/eval mode is being exported, but the model does
not have an optimizer.
"""
from tensorflow.python.keras import models as models_lib # pylint: disable=g-import-not-at-top
compile_clone = (mode != mode_keys.ModeKeys.PREDICT)
if compile_clone and not model.optimizer:
raise ValueError(
@@ -264,12 +301,12 @@
# Extract update and train ops from train/test/predict functions.
train_op = None
if mode == mode_keys.ModeKeys.TRAIN:
clone._make_train_function()
clone._make_train_function() # pylint: disable=protected-access
train_op = clone.train_function.updates_op
elif mode == mode_keys.ModeKeys.TEST:
clone._make_test_function()
clone._make_test_function() # pylint: disable=protected-access
else:
clone._make_predict_function()
clone._make_predict_function() # pylint: disable=protected-access
g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)
with session.Session().as_default():
@@ -292,7 +329,7 @@ def _export_mode(
# Add graph and variables to SavedModel.
# TODO(b/113134168): Switch to add_meta_graph_and_variables.
clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
builder._has_saved_variables = True
builder._has_saved_variables = True # pylint: disable=protected-access
# Add graph to the SavedModel builder.
builder.add_meta_graph(
@@ -313,7 +350,7 @@ def _create_signature_def_map(model, mode):
inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)}
if model.optimizer:
targets_dict = {x.name.split(':')[0]: x
for x in model._targets if x is not None}
for x in model._targets if x is not None} # pylint: disable=protected-access
inputs_dict.update(targets_dict)
outputs_dict = {name: x
for name, x in zip(model.output_names, model.outputs)}
@@ -325,9 +362,8 @@
local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
vars_to_add = set()
if metrics is not None:
from tensorflow.python.keras.metrics import Metric # pylint: disable=g-import-not-at-top
for key, value in six.iteritems(metrics):
if isinstance(value, Metric):
if isinstance(value, metrics_lib.Metric):
vars_to_add.update(value.variables)
# Convert Metric instances to (value_tensor, update_op) tuple.
metrics[key] = (value.result(), value.updates[0])
@@ -406,3 +442,823 @@ def load_from_saved_model(saved_model_path, custom_objects=None):
compat.as_text(constants.VARIABLES_FILENAME))
model.load_weights(checkpoint_prefix)
return model
################################################################################
# Functional Style/V2 SavedModel functions #
################################################################################
# All serialized attributes are listed within SerializedAttributes classes. See
# the docstring in SerializedAttributes for more context
# All attributes are saved under the 'keras_api' namespace. Only common
# endpoints are attached directly to the root object.
_KERAS_ATTR = 'keras_api'
# Keys for the serialization cache.
# Maps to the keras serialization dict {Layer --> SerializedAttributes object}
_KERAS_CACHE_KEY = 'keras_serialized_attributes'
class SerializedAttributes(object):
"""Class that tracks and validates all serialization attributes.
Keras models contain many Python-defined components. For example, the
trainable_variables property lists the model's trainable variables by
recursively retrieving the trainable variables from each of the child layers.
Another example is model.call, a Python function that calls child layers and
adds ops to the backend graph.
Only TensorFlow checkpointable objects and functions can be serialized to
SavedModel. Serializing a Keras model as-is results in a checkpointable object
that does not resemble a Keras model at all. Thus, extra checkpointable
objects and functions must be created during serialization.
**Defining new serialized attributes**
Child classes should be defined using:
SerializedAttributes.with_attributes(
'name', checkpointable_objects=[...], functions=[...], copy_from=[...])
This class is used to cache generated checkpointable objects and functions,
ensuring that new objects and functions are generated a single time.
**Usage during serialization**
Each Layer/Model object should have a corresponding instance of
SerializedAttributes. Create a new instance by calling
`SerializedAttributes.new(obj)`. Objects and functions may be saved using
`.set_and_validate_objects`/`.set_and_validate_functions`.
The properties `.checkpointable_objects` and `.functions` return the cached
values.
**Adding/changing attributes to save to SavedModel**
1. Change the call to `SerializedAttributes.with_attributes` in the correct
class:
- CommonEndpoints: Base attributes to be added during serialization. If
these attributes are present in a Trackable object, it can be
deserialized to a Keras Model.
- LayerAttributes: Attributes to serialize for Layer objects.
- ModelAttributes: Attributes to serialize for Model objects.
2. Update class docstring
3. Update arguments to any calls to `set_and_validate_*`. For example, if
`call_raw_tensors` is added to the ModelAttributes function list, then
a `call_raw_tensors` function should be passed to
`set_and_validate_functions`.
**Common endpoints vs other attributes**
Only common endpoints are attached directly to the root object. Keras-specific
attributes are saved to a separate trackable object with the name "keras_api".
The number of objects attached to the root is limited because any naming
conflicts will cause user code to break.
Another reason is that this will only affect users who call
`tf.saved_model.load` instead of `tf.keras.models.load_model`. These are
advanced users who are likely to have defined their own tf.functions and
trackable objects. The added Keras-specific attributes are kept out of the way
in the "keras_api" namespace.
Properties defined in this class may be used to filter out keras-specific
attributes:
- `functions_to_serialize`: Returns dict of functions to attach to the root
object.
- `checkpointable_objects_to_serialize`: Returns dict of objects to attach to
the root object (including separate trackable object containing
keras-specific attributes)
All changes to the serialized attributes must be backwards-compatible, so
attributes should not be removed or modified without sufficient justification.
"""
@staticmethod
def with_attributes(
name, checkpointable_objects=None, functions=None, copy_from=None):
"""Creates a subclass with all attributes as specified in the arguments.
Args:
name: Name of subclass
checkpointable_objects: List of checkpointable objects to be serialized
in the SavedModel.
functions: List of functions to be serialized in the SavedModel.
copy_from: List of other SerializedAttributes subclasses. The returned
class will copy checkpointable objects/functions from each subclass.
Returns:
Child class with attributes as defined in the `checkpointable_objects`
and `functions` lists.
"""
checkpointable_objects = checkpointable_objects or []
functions = functions or []
if copy_from is not None:
for cls in copy_from:
checkpointable_objects.extend(cls.all_checkpointable_objects)
functions.extend(cls.all_functions)
classdict = {
'all_checkpointable_objects': set(checkpointable_objects),
'all_functions': set(functions)}
return type(name, (SerializedAttributes,), classdict)
@staticmethod
def new(obj):
if isinstance(obj, training_lib.Model):
return ModelAttributes()
elif isinstance(obj, base_layer.Layer):
return LayerAttributes()
else:
raise TypeError('Internal error during serialization: Expected Keras '
'Layer object, got {} of type {}'.format(obj, type(obj)))
def __init__(self):
self._object_dict = {}
self._function_dict = {}
self._keras_trackable = AutoTrackable()
@property
def functions(self):
"""Returns dictionary of all functions."""
return {key: value for key, value in self._function_dict.items()
if value is not None}
@property
def checkpointable_objects(self):
"""Returns dictionary of all checkpointable objects."""
return {key: value for key, value in self._object_dict.items()
if value is not None}
@property
def functions_to_serialize(self):
"""Returns functions to attach to the root object during serialization."""
return {key: value for key, value in self.functions.items()
if key in CommonEndpoints.all_functions}
@property
def objects_to_serialize(self):
"""Returns objects to attach to the root object during serialization."""
objects = {key: value for key, value in self.checkpointable_objects.items()
if key in CommonEndpoints.all_checkpointable_objects}
objects[_KERAS_ATTR] = self._keras_trackable
return objects
def set_and_validate_functions(self, function_dict):
"""Saves function dictionary, and validates dictionary values."""
for key in self.all_functions:
if key in function_dict:
if (function_dict[key] is not None and # Not all functions are required
not isinstance(function_dict[key],
(defun.Function, def_function.Function))):
raise ValueError(
'Function dictionary contained a non-function object: {} (for key'
' {})'.format(function_dict[key], key))
self._function_dict[key] = function_dict[key]
setattr(self._keras_trackable, key, function_dict[key])
else:
raise ValueError('Function {} missing from serialized function dict.'
.format(key))
return self.functions
def set_and_validate_objects(self, object_dict):
"""Saves objects to a dictionary, and validates the values."""
for key in self.all_checkpointable_objects:
if key in object_dict:
if not isinstance(object_dict[key], trackable.Trackable):
raise ValueError(
'Object dictionary contained a non-trackable object: {} (for key'
' {})'.format(object_dict[key], key))
self._object_dict[key] = object_dict[key]
setattr(self._keras_trackable, key, object_dict[key])
else:
raise ValueError('Object {} missing from serialized object dict.'
.format(key))
return self.checkpointable_objects
class CommonEndpoints(SerializedAttributes.with_attributes(
'CommonEndpoints',
checkpointable_objects=['variables', 'trainable_variables',
'regularization_losses'],
functions=['__call__', 'call_and_return_all_conditional_losses',
'_default_save_signature'])):
"""Common endpoints shared by all models loadable by Keras.
List of all attributes:
variables: List of all variables in the model and its sublayers.
trainable_variables: List of all trainable variables in the model and its
sublayers.
regularization_losses: List of all unconditional losses (losses not dependent
on the inputs) in the model and its sublayers.
__call__: Function that takes inputs and returns the outputs of the model
call function.
call_and_return_all_conditional_losses: Function that returns a tuple of
(call function outputs, list of all losses that depend on the inputs).
_default_save_signature: Traced model call function. This is only included
if the top level exported object is a Keras model.
"""
class LayerAttributes(SerializedAttributes.with_attributes(
'LayerAttributes',
checkpointable_objects=['non_trainable_variables', 'layers', 'metrics'],
functions=['call_and_return_conditional_losses', 'activity_regularizer_fn'],
copy_from=[CommonEndpoints]
)):
"""Layer checkpointable objects + functions that are saved to the SavedModel.
List of all attributes:
All attributes from CommonEndpoints
non_trainable_variables: List of non-trainable variables in the layer and
its sublayers.
layers: List of all sublayers.
metrics: List of all metrics in the layer and its sublayers.
call_and_return_conditional_losses: Function that takes inputs and returns a
tuple of (outputs of the call function, list of input-dependent losses).
The list of losses excludes the activity regularizer function, which is
separate to allow the deserialized Layer object to define a different
activity regularizer.
activity_regularizer_fn: Callable that returns the activity regularizer loss
"""
class ModelAttributes(SerializedAttributes.with_attributes(
'ModelAttributes',
copy_from=[LayerAttributes])):
"""Model checkpointable objects + functions that are saved to the SavedModel.
List of all attributes:
All attributes from LayerAttributes (including CommonEndpoints)
"""
# TODO(kathywu): Add attributes `compile_losses` and `compile_metrics`, which
# list all losses and metrics defined by `model.compile`.
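# Illustrative sketch (hypothetical attribute names): a new attribute set
# would be declared by extension, e.g.:
#   NewModelAttributes = SerializedAttributes.with_attributes(
#       'NewModelAttributes',
#       checkpointable_objects=['compile_metrics'],
#       copy_from=[ModelAttributes])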
def serialize_all_attributes(layer, serialization_cache):
"""Serialize all attributes in the layer."""
save_model_default_signature = False
if _KERAS_CACHE_KEY not in serialization_cache:
keras_cache = serialization_cache[_KERAS_CACHE_KEY] = {}
if isinstance(layer, training_lib.Model):
# Only trace default signature if the root object is a Model. Since the
# keras cache key is only created in this method, we know that the object
# is root if the key does not yet exist in the cache.
save_model_default_signature = True
else:
keras_cache = serialization_cache[_KERAS_CACHE_KEY]
if layer in keras_cache:
return keras_cache[layer]
serialized_attr = keras_cache[layer] = SerializedAttributes.new(layer)
if _should_skip_serialization(layer):
return serialized_attr
object_dict = _wrap_layer_objects(layer, serialization_cache)
try:
function_dict = _wrap_layer_functions(layer, serialization_cache,
save_model_default_signature)
except (ValueError, TypeError) as e:
logging.warning('Skipping full serialization of object {}, because an '
'error occurred while tracing layer functions. Error '
'message: {}'.format(layer, e))
else:
# Add checkpointable objects and functions to the SerializedAttributes object
# only if all functions are successfully traced.
# The `set_and_validate_*` function ensures that all required attributes are
# exported with the correct type.
serialized_attr.set_and_validate_objects(object_dict)
serialized_attr.set_and_validate_functions(function_dict)
return serialized_attr
def _should_skip_serialization(layer):
"""Skip serializing extra objects and functions if layer inputs aren't set."""
if isinstance(layer, training_lib.Model):
try:
# pylint:disable=pointless-statement
layer.inputs
layer.input_names
# pylint:enable=pointless-statement
except AttributeError:
# If the model does not have inputs set, because it was not called or its
# input shapes were not recorded, we won't have a signature so can't trace
# a function. But the user may still save an object with this Model
# attached; we won't fail the whole tf.saved_model.save.
logging.warning('Skipping full serialization of Keras model {}, because '
'its inputs are not defined.'.format(layer))
return True
else:
return False
else:
if not layer.input_spec:
logging.warning('Skipping full serialization of Keras layer {}, because '
'it does not have an input spec defined.'.format(layer))
return True
if not layer.built:
logging.warning('Skipping full serialization of Keras layer {}, because '
'it is not built.'.format(layer))
return True
return False
def _wrap_layer_objects(layer, serialization_cache):
"""Returns extra trackable objects to attach to the serialized layer.
Args:
layer: Keras Layer object.
serialization_cache: Dictionary shared between all objects during
serialization.
Returns:
A dictionary containing all checkpointable objects from a
SerializedAttributes object. See LayerAttributes and ModelAttributes for
the entire list of objects.
"""
# Wrap all regularization losses as tf.functions.
# First, generate list of all regularization losses in this layer and
# sublayers.
regularization_losses = layer._callable_losses[:] # pylint: disable=protected-access
for child_layer in (
trackable_layer_utils.filter_empty_layer_containers(layer._layers)): # pylint: disable=protected-access
regularization_losses.extend(child_layer._callable_losses) # pylint: disable=protected-access
# Next, wrap all loss functions as tf.functions. Use the serialization cache
# to store already-wrapped functions.
keras_loss_cache = serialization_cache.setdefault('keras_losses', {})
wrapped_loss_functions = []
for loss_fn in regularization_losses:
if loss_fn in keras_loss_cache:
wrapped_loss_functions.append(keras_loss_cache[loss_fn])
else:
wrapped_loss = _wrap_unconditional_loss(loss_fn, len(keras_loss_cache))
keras_loss_cache[loss_fn] = wrapped_loss
wrapped_loss_functions.append(wrapped_loss)
return dict(
variables=data_structures.ListWrapper(layer.variables),
trainable_variables=data_structures.ListWrapper(
layer.trainable_variables),
non_trainable_variables=data_structures.ListWrapper(
layer.non_trainable_variables),
layers=data_structures.ListWrapper(
trackable_layer_utils.filter_empty_layer_containers(
layer._layers)), # pylint: disable=protected-access
metrics=data_structures.ListWrapper(layer.metrics),
regularization_losses=data_structures.ListWrapper(
wrapped_loss_functions))
def _wrap_layer_functions(layer, serialization_cache,
save_model_default_signature=False):
"""Returns dict of wrapped layer call function and losses in tf.functions.
Args:
layer: Keras Layer object.
serialization_cache: Dictionary shared between all objects during
serialization.
save_model_default_signature: Whether to save traced model call function.
Returns:
A dictionary containing all keras tf.functions to serialize. See
LayerAttributes and ModelAttributes for the list of all attributes.
"""
# Reset the losses of the layer and its children. The call function in each
# child layer is replaced with tf.functions.
original_attrs = _replace_child_layer_functions(layer, serialization_cache)
original_layer_losses = layer._losses[:] # pylint: disable=protected-access
with trackable.no_automatic_dependency_tracking_scope(layer):
layer._losses = [] # pylint: disable=protected-access
# Note that eager losses do not need to be saved since these functions
# create symbolic losses.
# Wrap all the layer call and activity regularizer functions.
call_fn_with_losses = _wrap_call_and_conditional_losses(layer)
fns = {'call_and_return_conditional_losses': call_fn_with_losses,
'__call__': _extract_outputs_from_fn(layer, call_fn_with_losses)}
if save_model_default_signature:
fns['_default_save_signature'] = saving_utils.trace_model_call(layer)
else:
fns['_default_save_signature'] = None
if layer.activity_regularizer is not None:
fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
fns['call_and_return_all_conditional_losses'] = (
_append_activity_regularizer_loss(
layer, call_fn_with_losses, fns['activity_regularizer_fn']))
else:
fns['activity_regularizer_fn'] = None
fns['call_and_return_all_conditional_losses'] = call_fn_with_losses
# Manually trigger traces before restoring the overwritten functions. The
# functions are traced within the layer call context to ensure that layer
# functions (e.g. add_loss) behave as though running in graph mode.
with base_layer_utils.call_context().enter(layer, None, build_graph=True):
for fn in fns.values():
if fn is not None and fn.input_signature is not None:
fn.get_concrete_function()
# Restore overwritten functions/losses
with trackable.no_automatic_dependency_tracking_scope(layer):
layer._losses = original_layer_losses # pylint: disable=protected-access
_restore_child_layer_functions(original_attrs)
return fns
def _replace_child_layer_functions(layer, serialization_cache):
"""Replaces functions in the children layers with wrapped tf.functions.
This step allows functions from parent layers to reference the wrapped
functions from their child layers instead of retracing the ops.
This function also resets all losses stored in the layer. These are stored in
the returned dictionary. Use `_restore_child_layer_functions` to restore
the original attributes.
Args:
layer: Keras Layer object.
serialization_cache: Dictionary shared between all objects during
serialization.
Returns:
Dictionary mapping layer objects -> original functions and losses:
{ Child layer 1: {
'losses': Original losses,
'call': Original call function
'activity_regularizer': Original activity regularizer},
Child layer 2: ...
}
"""
original_attrs = {}
for child_layer in trackable_layer_utils.filter_empty_layer_containers(
layer._layers): # pylint: disable=protected-access
# Save symbolic layer losses, which will be restored to maintain the same
# state.
original_attrs[child_layer] = {'losses': child_layer._losses[:]} # pylint: disable=protected-access
if child_layer not in serialization_cache[_KERAS_CACHE_KEY]:
layer_fns = (serialize_all_attributes(child_layer, serialization_cache)
.functions)
else:
layer_fns = serialization_cache[_KERAS_CACHE_KEY][child_layer].functions
if not layer_fns:
# This indicates either:
# - circular dependency, which means the current layer's functions
# should be wrapped first.
# - Child layer's inputs are not defined, so its functions have not been
# wrapped. In this case, no replacement is necessary so move on to the
# next child.
continue
original_attrs[child_layer]['call'] = child_layer.call
original_attrs[child_layer]['activity_regularizer'] = (
child_layer.activity_regularizer)
with trackable.no_automatic_dependency_tracking_scope(child_layer):
child_layer.activity_regularizer = layer_fns.get(
'activity_regularizer_fn')
child_layer.call = _use_wrapped_call(
child_layer, layer_fns['call_and_return_conditional_losses'])
child_layer._losses = [] # pylint: disable=protected-access
return original_attrs
def _restore_child_layer_functions(original_attrs):
"""Restores attributes replaced with `_replace_child_layer_functions`."""
for child_layer, attrs in original_attrs.items():
with trackable.no_automatic_dependency_tracking_scope(child_layer):
child_layer._losses = attrs['losses'] # pylint: disable=protected-access
if 'call' in attrs:
child_layer.call = attrs['call']
child_layer.activity_regularizer = attrs['activity_regularizer']
def _use_wrapped_call(layer, call_fn):
"""Creates fn that adds the losses returned by call_fn & returns the outputs.
Args:
layer: A Keras layer object
call_fn: tf.function returned by _wrap_call_and_conditional_losses.
Returns:
function that calls call_fn and returns the outputs. Losses returned by
call_fn are added to the layer losses.
"""
def wrapped_call(inputs, *args, **kwargs):
"""Returns the outputs from the call_fn, and adds the losses."""
if layer._expects_training_arg: # pylint: disable=protected-access
training = kwargs.pop('training', None)
if training is None:
training = K.learning_phase()
training = math_ops.cast(training, dtypes.bool)
outputs, losses = call_fn(inputs, training=training, *args, **kwargs)
else:
outputs, losses = call_fn(inputs)
layer.add_loss(losses, inputs)
return outputs
return wrapped_call
def _wrap_call_and_conditional_losses(layer):
"""Wraps call function that returns a tuple of (outputs, losses).
The losses returned are conditional on the inputs passed to the call function.
Unconditional losses (e.g. weight regularizeration) are wrapped separately.
Args:
layer: a Keras layer object
Returns:
call function that returns outputs and conditional losses -- excludes
activity regularizer
"""
if isinstance(layer, RevivedLayer):
return layer.call_and_return_conditional_losses
if (isinstance(layer.call, def_function.Function) and
layer.call.input_signature is not None):
input_signature = layer.call.input_signature
else:
if (isinstance(layer, training_lib.Model) and
saving_utils.model_input_signature(layer) is not None):
input_signature = saving_utils.model_input_signature(layer)
else:
input_signature = [nest.map_structure(
lambda x: input_spec.to_tensor_spec(x, layer.dtype),
layer.input_spec)]
# If input spec is too general, then don't define an input signature.
for spec in nest.flatten(input_signature):
if spec.shape == tensor_shape.TensorShape(None):
input_signature = None
break
if input_signature is not None and layer._expects_training_arg: # pylint: disable=protected-access
input_signature.append(
tensor_spec.TensorSpec(shape=[], dtype=dtypes.bool))
# Create function that generates both outputs and losses
layer_call = layer.call
if layer._expects_training_arg: # pylint: disable=protected-access
def call_and_return_conditional_losses(inputs, training):
_set_symbolic_learning_phase(training)
return layer_call(inputs, training=training), layer.get_losses_for(inputs)
else:
def call_and_return_conditional_losses(inputs):
K.set_learning_phase(0)
return layer_call(inputs), layer.get_losses_for(inputs)
return def_function.Function(
call_and_return_conditional_losses,
'{}_layer_call_and_return_conditional_losses'.format(layer.name),
input_signature=input_signature,
# TODO(kathywu): Investigate autograph error.
autograph=False)
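# Illustrative sketch (layer and shapes assumed): for a layer that expects a
# training argument, the wrapped function takes (inputs, training) and returns
# a tuple of (outputs, losses):
#   fn = _wrap_call_and_conditional_losses(layer)
#   outputs, losses = fn(tf.ones([1, 8]), False)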
def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
"""Returns a function that returns only call function outputs."""
if isinstance(layer, RevivedLayer):
return layer._original_call # pylint: disable=protected-access
if layer._expects_training_arg: # pylint: disable=protected-access
def call(inputs, training):
return call_and_return_conditional_losses(inputs, training)[0]
else:
def call(inputs):
return call_and_return_conditional_losses(inputs)[0]
return def_function.Function(
call, '{}_layer_call_fn'.format(layer.name),
input_signature=call_and_return_conditional_losses.input_signature,
# TODO(kathywu): Investigate autograph error.
autograph=False)
def _set_symbolic_learning_phase(value):
"""Set learning phase to a tensor value (for internal use only).
This is used when wrapping call functions as tf.functions that have training
as a tensor input. Thus, when `learning_phase()` is called, the training
tensor is returned. This function is called when saving a model to SavedModel.
Args:
value: A Tensor object.
Raises:
ValueError: If the input value is not a graph tensor
"""
graph = K.get_graph()
if not isinstance(value, ops.Tensor):
raise ValueError('Symbolic learning phase must be a graph tensor.')
K._GRAPH_LEARNING_PHASES[graph] = value # pylint: disable=protected-access
def _append_activity_regularizer_loss(
layer, call_fn_with_losses, activity_regularizer_fn):
"""Appends activity regularizer loss to losses returned by the wrapped fn."""
def fn(*args):
outputs, losses = call_fn_with_losses(*args)
losses.append(activity_regularizer_fn(outputs))
return outputs, losses
return def_function.Function(
fn,
'{}_layer_call_and_return_all_conditional_losses'.format(layer.name),
input_signature=call_fn_with_losses.input_signature,
# TODO(kathywu): Investigate autograph error.
autograph=False)
def _wrap_unconditional_loss(loss_fn, index):
"""Wraps callable/unconditonal loss, returning a serializable function."""
# Extract original loss function from partial function
fn = loss_fn.args[0] if isinstance(loss_fn, functools.partial) else loss_fn
if isinstance(fn, def_function.Function):
return fn
else:
return def_function.Function(
fn, 'loss_fn_{}'.format(index), input_signature=[])
def _wrap_activity_regularizer(layer):
"""Wraps the activity regularizer."""
if isinstance(layer.activity_regularizer, def_function.Function):
return layer.activity_regularizer
return def_function.Function(
layer.activity_regularizer,
'{}_activity_regularizer'.format(layer.name),
input_signature=[tensor_spec.TensorSpec(None, layer.dtype or K.floatx())])
def load_from_saved_model_v2(path):
"""Loads Keras objects from a SavedModel.
Any Keras layer or model saved to the SavedModel is loaded back as a Keras
object. Other objects are loaded as regular trackable objects (same as
`tf.saved_model.load`).
Currently, Keras saving/loading only retains the Keras object's weights,
losses, and call function.
The loaded model can be re-compiled, but the original optimizer, compiled loss
functions, and metrics are not retained. This is temporary, and `model.save`
will soon be able to serialize compiled models.
Args:
path: Path to SavedModel.
Returns:
Object loaded from SavedModel.
"""
# TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
# TODO(kathywu): Add code to load from objects that contain all endpoints
return load.load_internal(path, loader_cls=KerasObjectLoader)
class KerasObjectLoader(load.Loader):
"""Loader that recreates Keras objects."""
def __init__(self, *args, **kwargs):
super(KerasObjectLoader, self).__init__(*args, **kwargs)
self._finalize()
def _finalize(self):
# pylint: disable=protected-access
for node in self._nodes:
if isinstance(node, RevivedLayer):
losses = node._serialized_attributes.get('regularization_losses', [])
for loss in losses:
node.add_loss(loss)
# Use wrapped activity regularizer function if the layer's activity
# regularizer wasn't created during initialization.
if node.activity_regularizer is None:
node.activity_regularizer = getattr(node.keras_api,
'activity_regularizer_fn', None)
if isinstance(node, RevivedModel):
# Since this revived object is technically a subclassed model (even if
# the original model is functional/sequential), inputs should be set.
input_signature = (
node.keras_api.call_and_return_conditional_losses.input_signature[0]
)
node._set_inputs(input_signature)
# pylint: enable=protected-access
def _recreate_base_user_object(self, proto):
revived_classes = {
'_tf_keras_layer': (RevivedLayer, base_layer.Layer),
'_tf_keras_network': (RevivedNetwork, network_lib.Network),
'_tf_keras_model': (RevivedModel, training_lib.Model)
}
parent_classes = revived_classes.get(proto.identifier, None)
if parent_classes is not None:
metadata = json.loads(proto.metadata)
revived_cls = type(
compat.as_str(metadata['class_name']),
parent_classes,
{'__setattr__': parent_classes[1].__setattr__})
obj = revived_cls._init_from_metadata(metadata) # pylint: disable=protected-access
return obj, revived_cls._revive_setter # pylint: disable=protected-access
return super(KerasObjectLoader, self)._recreate_base_user_object(proto)
# TODO(kathywu): Centrally define keys and functions for both serialization and
# deserialization.
class RevivedLayer(object):
"""Keras layer loaded from a SavedModel."""
@classmethod
def _init_from_metadata(cls, metadata):
"""Create revived layer from metadata stored in the SavedModel proto."""
init_args = dict(
name=metadata['name'],
trainable=metadata['trainable'])
if metadata.get('dtype') is not None:
init_args['dtype'] = metadata['dtype']
if metadata.get('batch_input_shape') is not None:
init_args['batch_input_shape'] = metadata['batch_input_shape']
revived_obj = cls(**init_args)
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
# pylint:disable=protected-access
revived_obj._expects_training_arg = metadata['expects_training_arg']
if metadata.get('config') is not None:
revived_obj._config = metadata['config']
if metadata.get('input_spec') is not None:
revived_obj.input_spec = input_spec.InputSpec.from_config(
metadata['input_spec'])
if metadata.get('activity_regularizer') is not None:
revived_obj.activity_regularizer = regularizers.deserialize(
metadata['activity_regularizer'])
# Store attributes revived from SerializedAttributes in an untracked
# dictionary. The attributes are the ones listed in CommonEndpoints or
# "keras_api" for keras-specific attributes.
revived_obj._serialized_attributes = {}
# pylint:enable=protected-access
return revived_obj
def _revive_setter(self, name, value):
"""Reattaches attributes from the SavedModel to the newly revived object."""
if (name in CommonEndpoints.all_functions or
name in CommonEndpoints.all_checkpointable_objects or
name == _KERAS_ATTR):
self._serialized_attributes[name] = value
else:
setattr(self, name, value)
@property
def keras_api(self):
return self._serialized_attributes[_KERAS_ATTR]
def get_config(self):
if hasattr(self, '_config'):
return self._config
else:
raise NotImplementedError
def call(self, inputs, *args, **kwargs):
"""Calls the revived layer and add conditional losses."""
call_fn = _use_wrapped_call(
self, self.keras_api.call_and_return_conditional_losses)
return call_fn(inputs, *args, **kwargs)
class RevivedNetwork(RevivedLayer):
"""Keras network of layers loaded from a SavedModel."""
@classmethod
def _init_from_metadata(cls, metadata):
"""Create revived network from metadata stored in the SavedModel proto."""
# TODO(kathywu): Refactor logic here so that RevivedNetwork reuses the
# metadata-restoring logic in RevivedLayer._init_from_metadata.
revived_obj = cls(name=metadata['name'])
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
# pylint:disable=protected-access
if metadata.get('dtype') is not None:
revived_obj._dtype = metadata['dtype']
revived_obj.trainable = metadata['trainable']
revived_obj._expects_training_arg = metadata['expects_training_arg']
if metadata.get('config') is not None:
revived_obj._config = metadata['config']
if metadata.get('activity_regularizer') is not None:
revived_obj.activity_regularizer = regularizers.deserialize(
metadata['activity_regularizer'])
# Store attributes revived from SerializedAttributes in an untracked
# dictionary. The attributes are the ones listed in CommonEndpoints or
# "keras_api" for Keras-specific attributes.
revived_obj._serialized_attributes = {}
# pylint:enable=protected-access
return revived_obj
class RevivedModel(RevivedNetwork):
"""Keras model loaded from a SavedModel."""
@classmethod
def _init_from_metadata(cls, metadata):
"""Create revived model from metadata stored in the SavedModel proto."""
revived_obj = super(RevivedModel, cls)._init_from_metadata(metadata)
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
if 'training_config' in metadata:
revived_obj._training_config = metadata['training_config'] # pylint:disable=protected-access
return revived_obj


@ -28,19 +28,28 @@ from tensorflow.python import keras
from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import regularizers
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import training as model_lib
from tensorflow.python.keras.optimizer_v2 import adadelta
from tensorflow.python.keras.saving import saved_model as keras_saved_model
from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load as tf_load
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import model_utils
from tensorflow.python.saved_model import save as tf_save
from tensorflow.python.training import training as training_module
@ -186,7 +195,7 @@ class TestModelSavingandLoading(test.TestCase):
# For now, saving a subclassed model should raise an error. The error should
# go away once loading from SavedModel.pb is supported.
class SubclassedModel(training.Model):
class SubclassedModel(model_lib.Model):
def __init__(self):
super(SubclassedModel, self).__init__()
@ -205,10 +214,15 @@ class TestModelSavingandLoading(test.TestCase):
class LayerWithLearningPhase(keras.engine.base_layer.Layer):
def call(self, x):
phase = keras.backend.learning_phase()
def build(self, input_shape):
self.input_spec = keras.layers.InputSpec(shape=[None] * len(input_shape))
self.built = True
def call(self, x, training=None):
if training is None:
training = keras.backend.learning_phase()
output = tf_utils.smart_cond(
phase, lambda: x * 0, lambda: array_ops.identity(x))
training, lambda: x * 0, lambda: array_ops.identity(x))
if not context.executing_eagerly():
output._uses_learning_phase = True # pylint: disable=protected-access
return output
@ -538,5 +552,171 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
self.assertAllClose(ref_predict, predictions, atol=1e-05)
@test_util.run_all_in_graph_and_eager_modes
class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
def _save_model_dir(self, dirname='saved_model'):
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
return os.path.join(temp_dir, dirname)
@keras_parameterized.run_with_all_model_types
def test_model_save_and_load(self):
input_arr = np.random.random((1, 3)).astype(np.float32)
target_arr = np.random.random((1, 4)).astype(np.float32)
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
model.layers[-1].activity_regularizer = regularizers.get('l2')
model.activity_regularizer = regularizers.get('l2')
model.compile(
loss='mse',
optimizer='rmsprop')
model.train_on_batch(input_arr, target_arr)
def callable_loss():
return math_ops.reduce_sum(model.weights[0])
model.add_loss(callable_loss)
saved_model_dir = self._save_model_dir()
tf_save.save(model, saved_model_dir)
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
self.evaluate(variables.variables_initializer(loaded.variables))
self.assertAllClose(self.evaluate(model.weights),
self.evaluate(loaded.weights))
input_arr = constant_op.constant(
np.random.random((1, 3)).astype(np.float32))
self.assertAllClose(self.evaluate(model(input_arr)),
self.evaluate(loaded(input_arr)))
# Validate losses. The order of conditional losses may change between the
# model and loaded model, so sort the losses first.
if context.executing_eagerly():
self.assertAllClose(sorted(self.evaluate(model.losses)),
sorted(self.evaluate(loaded.losses)))
else:
self.assertAllClose(self.evaluate(model.get_losses_for(None)),
self.evaluate(loaded.get_losses_for(None)))
self.assertAllClose(
sorted(self.evaluate(model.get_losses_for(input_arr))),
sorted(self.evaluate(loaded.get_losses_for(input_arr))))
def test_trainable_weights(self):
layer = keras.layers.Dense(4, name='custom_layer')
layer.build([3,])
layer.add_weight(
'extra_weight', shape=[],
initializer=init_ops.constant_initializer(11),
trainable=True)
layer.add_weight(
'extra_weight_2', shape=[],
initializer=init_ops.constant_initializer(12),
trainable=False)
saved_model_dir = self._save_model_dir()
self.evaluate(variables.variables_initializer(layer.variables))
tf_save.save(layer, saved_model_dir)
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
self.evaluate(variables.variables_initializer(loaded.variables))
equal_attrs = ['name', '_expects_training_arg', 'trainable']
for attr in equal_attrs:
self.assertEqual(getattr(layer, attr), getattr(loaded, attr))
all_close = ['weights', 'trainable_weights', 'non_trainable_weights']
for attr in all_close:
self.assertAllClose(self.evaluate(getattr(layer, attr)),
self.evaluate(getattr(loaded, attr)))
def test_maintains_losses(self):
"""Tests that the layer losses do not change before and after export."""
class LayerWithLoss(keras.layers.Layer):
def call(self, inputs):
self.add_loss(math_ops.reduce_sum(inputs), inputs)
return inputs
model = keras.models.Sequential([LayerWithLoss()])
model.compile(
loss='mse',
optimizer='rmsprop')
input_arr = np.random.random((1, 3)).astype(np.float32)
target_arr = np.random.random((1, 3)).astype(np.float32)
# Test that symbolic losses are maintained (train_on_batch saves symbolic
# losses.)
model.train_on_batch(input_arr, target_arr)
previous_losses = model.losses[:]
saved_model_dir = self._save_model_dir()
tf_save.save(model, saved_model_dir)
self.assertAllEqual(previous_losses, model.losses)
if context.executing_eagerly():
# Test that eager losses are maintained.
model(input_arr) # Calls model eagerly, creating eager losses.
previous_losses = model.losses[:]
tf_save.save(model, saved_model_dir)
self.assertAllEqual(previous_losses, model.losses)
def test_layer_with_learning_phase(self):
layer = LayerWithLearningPhase()
layer.build([None, None])
saved_model_dir = self._save_model_dir()
tf_save.save(layer, saved_model_dir)
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
input_arr = array_ops.ones((4, 3))
# Run the layer, and use the Keras backend learning phase
keras.backend.set_learning_phase(0)
self.assertAllEqual(input_arr, loaded(input_arr))
keras.backend.set_learning_phase(1)
self.assertAllEqual(array_ops.zeros((4, 3)), loaded(input_arr))
# Run the layer while explicitly setting the training argument
self.assertAllEqual(
input_arr, loaded(input_arr, training=constant_op.constant(False)))
self.assertAllEqual(
array_ops.zeros((4, 3)),
loaded(input_arr, training=constant_op.constant(True)))
@keras_parameterized.run_with_all_model_types
def test_standard_loader(self):
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
model.activity_regularizer = regularizers.get('l2')
def eager_loss():
return math_ops.reduce_sum(model.weights[0])
model.add_loss(eager_loss)
# Call predict to ensure that all layers are built and inputs are set.
model.predict(np.random.random((1, 3)))
saved_model_dir = self._save_model_dir()
tf_save.save(model, saved_model_dir)
loaded = tf_load.load(saved_model_dir)
self.evaluate(variables.variables_initializer(loaded.variables))
all_close = ['variables', 'trainable_variables',
'non_trainable_variables']
for attr in all_close:
self.assertAllClose(self.evaluate(getattr(model, attr)),
self.evaluate(getattr(loaded.keras_api, attr)))
self.assertLen(loaded.regularization_losses, 1)
expected_layers = len(model.layers)
if testing_utils.get_model_type() == 'sequential':
# The autogenerated Input layer is hidden in the model.layers list,
# but included in the loaded sub-layers.
expected_layers += 1
self.assertEqual(expected_layers, len(loaded.keras_api.layers))
input_arr = array_ops.ones((4, 3))
training_bool = constant_op.constant(False)
if model._expects_training_arg:
call_args = [input_arr, training_bool]
else:
call_args = [input_arr]
self.assertAllClose(self.evaluate(model(input_arr)),
self.evaluate(loaded(*call_args)))
if __name__ == '__main__':
test.main()


@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utils related to keras model saving."""
from __future__ import absolute_import
from __future__ import division
@ -22,6 +21,9 @@ import collections
from tensorflow.python.eager import def_function
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import optimizers
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@ -43,7 +45,53 @@ def extract_model_metrics(model):
# TODO(psv/kathywu): use this implementation in model to estimator flow.
# We are not using model.metrics here because we want to exclude the metrics
# added using the `add_metric` API.
return {m.name: m for m in model._compile_metric_functions}
return {m.name: m for m in model._compile_metric_functions} # pylint: disable=protected-access
def model_input_signature(model):
"""Inspect model to get its input signature.
The model's input signature is a list with a single (possibly-nested) object.
This is due to the Keras-enforced restriction that tensor inputs must be
passed in as the first argument.
For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>}
will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}]
Args:
model: Keras Model object
Returns:
A list containing either a single TensorSpec or an object with nested
TensorSpecs.
"""
try:
inputs = model.inputs
input_names = model.input_names
except AttributeError:
return None
flat_inputs = nest.flatten(inputs)
flat_input_names = nest.flatten(input_names)
flat_input_specs = []
for input_tensor, input_name in zip(flat_inputs, flat_input_names):
# If the user has not explicitly provided the input_signature, we
# create it from the inputs. We make sure to set the first dimension
# (batch) to None here, as in serving or retraining, batch should not
# be fixed. See b/132783590 for context.
input_shape = [None] + input_tensor.shape[1:].as_list()
flat_input_specs.append(tensor_spec.TensorSpec(
shape=input_shape, dtype=input_tensor.dtype,
name=input_name))
input_specs = nest.pack_sequence_as(structure=inputs,
flat_sequence=flat_input_specs)
# Return a list with a single element as the model's input signature.
if isinstance(input_specs, collections.Sequence) and len(input_specs) == 1:
# Note that the isinstance check filters out single-element dictionaries,
# which should also be wrapped as a single-element list.
return input_specs
else:
return [input_specs]
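A sketch of the resulting signature for a single-input functional model (assuming this helper lives at `tensorflow.python.keras.saving.saving_utils`; the input name is illustrative):
```python
import tensorflow as tf
from tensorflow.python.keras.saving import saving_utils  # assumed path

inp = tf.keras.Input(shape=(3,), name='feature1')
model = tf.keras.Model(inp, tf.keras.layers.Dense(2)(inp))

specs = saving_utils.model_input_signature(model)
# A one-element list, with the batch dimension forced to None:
# [TensorSpec(shape=(None, 3), dtype=tf.float32, name='feature1')]
```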
def trace_model_call(model, input_signature=None):
@ -65,37 +113,14 @@ def trace_model_call(model, input_signature=None):
input_signature = model.call.input_signature
if input_signature is None:
try:
inputs = model.inputs
input_names = model.input_names
except AttributeError:
raise ValueError(
'Model {} cannot be saved because the input shapes have not been '
'set. Usually, input shapes are automatically determined from calling'
' .fit() or .predict(). To manually set the shapes, call '
'model._set_inputs(inputs).'.format(model))
flat_inputs = nest.flatten(inputs)
flat_input_names = nest.flatten(input_names)
flat_input_specs = []
for input_tensor, input_name in zip(flat_inputs, flat_input_names):
# If the user has not explicitly provided the input_signature, we
# create it from the inputs. We make sure to set the first dimension
# (batch) to None here, as in serving or retraining, batch should not
# be fixed. See b/132783590 for context.
input_shape = [None] + input_tensor.shape[1:].as_list()
flat_input_specs.append(tensor_spec.TensorSpec(
shape=input_shape, dtype=input_tensor.dtype,
name=input_name))
input_specs = nest.pack_sequence_as(structure=inputs,
flat_sequence=flat_input_specs)
# The input signature of the call function is a list with one element, since
# all tensor inputs must be passed in as the first argument. Single-element
# dictionaries and other non-sequence types must also be wrapped.
if (len(input_specs) > 1
or not isinstance(input_specs, collections.Sequence)):
input_signature = [input_specs]
else:
input_signature = input_specs
input_signature = model_input_signature(model)
if input_signature is None:
raise ValueError(
'Model {} cannot be saved because the input shapes have not been '
'set. Usually, input shapes are automatically determined from calling'
' .fit() or .predict(). To manually set the shapes, call '
'model._set_inputs(inputs).'.format(model))
# TODO(mdan): Should the model's call be autographed by default?
@def_function.function(input_signature=input_signature, autograph=False)
@ -113,3 +138,55 @@ def trace_model_call(model, input_signature=None):
return {name: output for name, output in zip(output_names, outputs_list)}
return _wrapped_model
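Continuing the sketch above, the traced function can be concretized and called; its outputs are keyed by the model's output names (the 'dense' key below is illustrative):
```python
serving_fn = saving_utils.trace_model_call(model)  # signature inferred
concrete = serving_fn.get_concrete_function()
out = concrete(tf.ones((4, 3)))
# `out` is a dict keyed by output name, e.g. {'dense': <tf.Tensor ...>}
```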
def model_metadata(model, include_optimizer=True, require_config=True):
"""Returns a dictionary containing the model metadata."""
from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
from tensorflow.python.keras.optimizer_v2 import optimizer_v2 # pylint: disable=g-import-not-at-top
model_config = {'class_name': model.__class__.__name__}
try:
model_config['config'] = model.get_config()
except NotImplementedError as e:
if require_config:
raise e
metadata = dict(
keras_version=str(keras_version),
backend=K.backend(),
model_config=model_config)
if model.optimizer and include_optimizer:
if isinstance(model.optimizer, optimizers.TFOptimizer):
logging.warning(
'TensorFlow optimizers do not '
'make it possible to access '
'optimizer attributes or optimizer state '
'after instantiation. '
'As a result, we cannot save the optimizer '
'as part of the model save file. '
'You will have to compile your model again after loading it. '
'Prefer using a Keras optimizer instead '
'(see keras.io/optimizers).')
else:
metadata['training_config'] = {
'loss': model.loss,
# pylint: disable=protected-access
'metrics': model._compile_metrics,
'weighted_metrics': model._compile_weighted_metrics,
# pylint: enable=protected-access
'sample_weight_mode': model.sample_weight_mode,
'loss_weights': model.loss_weights,
}
if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
raise NotImplementedError(
'As of now, Optimizers loaded from SavedModel cannot be saved. '
'If you\'re calling `model.save` or `tf.keras.models.save_model`, '
'please set the `include_optimizer` option to `False`. For '
'`tf.saved_model.save`, delete the optimizer from the model.')
else:
optimizer_config = {
'class_name': model.optimizer.__class__.__name__,
'config': model.optimizer.get_config()}
metadata['training_config']['optimizer_config'] = optimizer_config
return metadata
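A sketch of the metadata produced for a compiled model (key names are taken from the function body above; the printed values are illustrative and assume `model.compile(loss='mse', optimizer='rmsprop')` was called):
```python
meta = saving_utils.model_metadata(model, include_optimizer=True)
meta['keras_version']   # e.g. '2.2.4-tf'
meta['backend']         # 'tensorflow'
meta['model_config']    # {'class_name': 'Model', 'config': {...}}
meta['training_config']['optimizer_config']  # {'class_name': 'RMSprop', ...}
```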


@ -99,7 +99,7 @@ class _WrapperFunction(function.ConcreteFunction):
return super(_WrapperFunction, self)._call_flat(args, captured_inputs)
class _Loader(object):
class Loader(object):
"""Helper class to load an object-based SavedModel."""
def __init__(self, object_graph_proto, saved_model_proto, export_dir):
@ -249,7 +249,7 @@ class _Loader(object):
# Note: if an object has an attribute `__call__` add a class method
# that allows `obj()` syntax to work. This is done per-instance to
# allow `callable` to be used to find out if an object is callable.
if reference.local_name == "__call__":
if reference.local_name == "__call__" and not callable(obj):
setattr(type(obj), "__call__", _call_attribute)
def _restore_checkpoint(self):
@ -308,16 +308,20 @@ class _Loader(object):
"""Instantiates a SavedUserObject."""
looked_up = revived_types.deserialize(proto)
if looked_up is None:
# Note: each user object has its own class. This allows making each one
# individually callable by adding a `__call__` method to the classes of
# object instances that have a `__call__` property.
class _UserObject(tracking.AutoTrackable):
pass
return _UserObject(), setattr
return self._recreate_base_user_object(proto)
return looked_up
def _recreate_base_user_object(self, proto):
del proto
# Note: each user object has its own class. This allows making each one
# individually callable by adding a `__call__` method to the classes of
# object instances that have a `__call__` property.
class _UserObject(tracking.AutoTrackable):
pass
return _UserObject(), setattr
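A toy sketch of why each user object gets its own class: attaching `__call__` to one instance's class must not make every other revived object callable:
```python
def _fresh_user_object():
  class _UserObject(object):  # stand-in for tracking.AutoTrackable
    pass
  return _UserObject()

a = _fresh_user_object()
b = _fresh_user_object()
setattr(type(a), '__call__', lambda self: 'called')
print(callable(a), callable(b))  # prints: True False
```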
def _recreate_asset(self, proto):
filename = os.path.join(
saved_model_utils.get_assets_dir(self._export_dir),
@ -386,7 +390,7 @@ class _RestoredResource(tracking.TrackableResource):
def _initialize(self):
raise RuntimeError()
def _list_functions_for_serialization(self):
def _list_functions_for_serialization(self, unused_serialization_cache):
# Overwrite this method to avoid the implementation of
# base class to re-wrap the polymorphic functions into
# another layer of `tf.function`.
@ -426,6 +430,22 @@ def load(export_dir, tags=None):
assert 6. == imported.f(x=tf.constant(2.)).numpy()
```
_Loading Keras models_
Keras models are trackable, so they can be saved to SavedModel. The object
returned by `tf.saved_model.load` is not a Keras object (i.e. doesn't have
`.fit`, `.predict`, etc. methods). A few attributes and functions are still
available: `.variables`, `.trainable_variables` and `.__call__`.
```python
model = tf.keras.Model(...)
tf.saved_model.save(model, path)
imported = tf.saved_model.load(path)
outputs = imported(inputs)
```
Use `tf.keras.models.load_model` to restore the Keras model.
_Importing SavedModels from TensorFlow 1.x_
SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
@ -462,6 +482,11 @@ def load(export_dir, tags=None):
Raises:
ValueError: If `tags` don't match a MetaGraph in the SavedModel.
"""
return load_internal(export_dir, tags)
def load_internal(export_dir, tags=None, loader_cls=Loader):
"""Loader implementation."""
if tags is not None and not isinstance(tags, set):
# Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
# sequences for nest.flatten, so we put those through as-is.
@ -479,9 +504,9 @@ def load(export_dir, tags=None):
.format(export_dir, meta_graph_def.meta_info_def.tags, tags))
object_graph_proto = meta_graph_def.object_graph_def
with ops.init_scope():
loader = _Loader(object_graph_proto,
saved_model_proto,
export_dir)
loader = loader_cls(object_graph_proto,
saved_model_proto,
export_dir)
root = loader.get(0)
root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
root.tensorflow_git_version = (


@ -1711,5 +1711,26 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
f(x=constant_op.constant([[-1.]]))["output_0"].numpy())
def test_object_with_extra_dependencies(self):
class Extra(tracking.AutoTrackable):
def _list_extra_dependencies_for_serialization(self, cache):
if self not in cache:
cache[self] = {"a": variables.Variable(5.)}
return cache[self]
root = Extra()
path = tempfile.mkdtemp(prefix=self.get_temp_dir())
save.save(root, path)
imported = load.load(path)
self.assertEqual(5, self.evaluate(imported.a))
root.a = variables.Variable(3.)
with self.assertRaisesRegexp(
ValueError,
"object has an attribute named a, which is reserved."):
save.save(root, path)
if __name__ == "__main__":
test.main()


@ -165,3 +165,7 @@ def deserialize(proto):
if type_registration.should_load(proto):
return (type_registration.from_proto(proto), type_registration.setter)
return None
def registered_identifiers():
return _REVIVED_TYPE_REGISTRY.keys()


@ -91,6 +91,10 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):
# Object -> (name -> dep)
self._extra_dependencies = object_identity.ObjectIdentityDictionary()
self._functions = object_identity.ObjectIdentityDictionary()
# Cache shared between objects in the same object graph. This is passed to
# each trackable object's `_list_extra_dependencies_for_serialization` and
# `_list_functions_for_serialization` methods.
self._serialization_cache = object_identity.ObjectIdentityDictionary()
def add_object(self, parent_node, name_in_parent, subgraph_root):
"""Attach an object to `parent_node`, overriding any existing dependency."""
@ -99,11 +103,23 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):
def list_dependencies(self, obj):
"""Overrides a parent method to include `add_object` objects."""
extra_dependencies = self._extra_dependencies.get(obj, {})
extra_dependencies = self.list_extra_dependencies(obj)
extra_dependencies.update(self._extra_dependencies.get(obj, {}))
used_names = set()
for name, dep in super(_AugmentedGraphView, self).list_dependencies(obj):
used_names.add(name)
if name in extra_dependencies:
# Extra dependencies (except for `.signatures`, which is always added
# when saving) should not have naming conflicts with dependencies
# defined by the user.
if name != signature_serialization.SIGNATURE_ATTRIBUTE_NAME:
raise ValueError(
"Error when exporting object {} of with identifier={}. The object"
" has an attribute named {}, which is reserved. List of all "
"reserved attributes: {}".format(
obj, obj._object_identifier, # pylint: disable=protected-access
name, extra_dependencies.keys()))
yield base.TrackableReference(name, extra_dependencies[name])
else:
yield base.TrackableReference(name, dep)
@ -112,10 +128,15 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):
continue
yield base.TrackableReference(name, dep)
def list_extra_dependencies(self, obj):
return obj._list_extra_dependencies_for_serialization( # pylint: disable=protected-access
self._serialization_cache)
def list_functions(self, obj):
obj_functions = self._functions.get(obj, None)
if obj_functions is None:
obj_functions = obj._list_functions_for_serialization() # pylint: disable=protected-access
obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access
self._serialization_cache)
self._functions[obj] = obj_functions
return obj_functions
@ -614,10 +635,13 @@ def _write_object_proto(obj, proto, asset_file_def_index):
registered_type_proto = revived_types.serialize(obj)
if registered_type_proto is None:
# Fallback for types with no matching registration
# pylint:disable=protected-access
registered_type_proto = saved_object_graph_pb2.SavedUserObject(
identifier="_generic_user_object",
identifier=obj._object_identifier,
version=versions_pb2.VersionDef(
producer=1, min_consumer=1, bad_consumers=[]))
producer=1, min_consumer=1, bad_consumers=[]),
metadata=obj._tracking_metadata)
# pylint:enable=protected-access
proto.user_object.CopyFrom(registered_type_proto)
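For a Keras layer, the fallback proto written above would look roughly like this (a sketch; the metadata JSON is abbreviated and illustrative):
```python
from tensorflow.core.framework import versions_pb2
from tensorflow.core.protobuf import saved_object_graph_pb2

proto = saved_object_graph_pb2.SavedUserObject(
    identifier='_tf_keras_layer',
    version=versions_pb2.VersionDef(producer=1, min_consumer=1),
    metadata='{"class_name": "Dense", "name": "dense_1"}')
```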


@ -204,7 +204,7 @@ class _SignatureMap(collections.Mapping, base.Trackable):
def __repr__(self):
return "_SignatureMap({})".format(self._signatures)
def _list_functions_for_serialization(self):
def _list_functions_for_serialization(self, unused_serialization_cache):
return {
key: value for key, value in self.items()
if isinstance(value, (def_function.Function, defun.ConcreteFunction))


@ -32,6 +32,7 @@ from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
# Key where the object graph proto is saved in a TensorBundle
@ -463,6 +464,38 @@ def no_automatic_dependency_tracking(method):
target=method, decorator_func=_method_wrapper)
@tf_contextlib.contextmanager
def no_automatic_dependency_tracking_scope(obj):
"""A context that disables automatic dependency tracking when assigning attrs.
Objects that inherit from AutoTrackable automatically create dependencies
on trackable objects through attribute assignments, and wrap data structures
(lists or dicts) in trackable classes. This scope may be used to temporarily
disable this behavior. It works similarly to the decorator
`no_automatic_dependency_tracking`.
Example usage:
```
model = tf.keras.Model()
model.arr1 = [] # Creates a ListWrapper object
with no_automatic_dependency_tracking_scope(model):
model.arr2 = [] # Creates a regular, untracked python list
```
Args:
obj: A trackable object.
Yields:
A scope in which the object doesn't track dependencies.
"""
previous_value = getattr(obj, "_setattr_tracking", True)
obj._setattr_tracking = False # pylint: disable=protected-access
try:
yield
finally:
obj._setattr_tracking = previous_value # pylint: disable=protected-access
class Trackable(object):
"""Base class for `Trackable` objects without automatic dependencies.
@ -548,6 +581,23 @@ class Trackable(object):
# building.
self._self_name_based_restores = set()
@property
def _object_identifier(self):
"""String used to identify this object in a SavedModel.
Generally, the object identifier is constant across objects of the same
class, while the metadata field is used for instance-specific data.
Returns:
String object identifier.
"""
return "_generic_user_object"
@property
def _tracking_metadata(self):
"""String containing object metadata, which is saved to the SavedModel."""
return ""
def _no_dependency(self, value):
"""If automatic dependency tracking is enabled, ignores `value`."""
return value
@ -881,15 +931,50 @@ class Trackable(object):
"""
return {}
def _list_functions_for_serialization(self):
def _list_extra_dependencies_for_serialization(self, serialization_cache):
"""Lists extra dependencies to serialize.
Internal sub-classes can override this method to return extra dependencies
that should be saved with the object during SavedModel serialization. For
example, this is used to save `trainable_variables` in Keras models. The
Python property `trainable_variables` contains logic to iterate through the
weights from the model and its sublayers. The serialized Keras model saves
`trainable_weights` as a trackable list of variables.
PLEASE NOTE when overriding this method:
1. This function may only generate new trackable objects the first time it
is called.
2. The returned dictionary must not have naming conflicts with
dependencies tracked by the root. In other words, if the root is
tracking `object_1` with name 'x', and this function returns
`{'x': object_2}`, an error is raised when saving.
Args:
serialization_cache: A dictionary shared between all objects in the same
object graph. This object is passed to both
`_list_extra_dependencies_for_serialization` and
`_list_functions_for_serialization`.
Returns:
A dictionary mapping attribute names to trackable objects.
"""
del serialization_cache
return dict()
def _list_functions_for_serialization(self, serialization_cache):
"""Lists the functions of this trackable to serialize.
Internal sub-classes can override this with specific logic. E.g.
`AutoTrackable` provides an implementation that returns the attributes
that are functions.
Args:
serialization_cache: Dictionary passed to all objects in the same object
graph during serialization.
Returns:
A dictionary mapping attribute names to `Function` or
`ConcreteFunction`.
"""
del serialization_cache
return dict()
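A minimal sketch of overriding the extra-dependencies hook on an `AutoTrackable` subclass, mirroring the `Extra` class in the test above; note the cache guard, which ensures new trackables are only created on the first serialization:
```python
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import tracking

class Extra(tracking.AutoTrackable):

  def _list_extra_dependencies_for_serialization(self, serialization_cache):
    # Create the extra dependency once; reuse it on later serializations.
    if self not in serialization_cache:
      serialization_cache[self] = {'a': variables.Variable(5.)}
    return serialization_cache[self]
```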


@ -76,7 +76,7 @@ def _wrap_or_unwrap(value):
elif type(value) == collections.OrderedDict:
return _DictWrapper(value)
elif type(value) == list:
return _ListWrapper(value)
return ListWrapper(value)
else:
return value
# pylint: enable=unidiomatic-typecheck
@ -371,9 +371,9 @@ class List(TrackableDataStructure, collections.Sequence):
# TODO(tomhennigan) Update to collections.UserList?
# TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop
# Python 3.4 support (may still be tricky).
class _ListWrapper(List, collections.MutableSequence,
# Shadowed, but there for isinstance checks.
list):
class ListWrapper(List, collections.MutableSequence,
# Shadowed, but there for isinstance checks.
list):
"""Wraps the built-in `list` to support restore-on-create for variables.
Unlike `List`, this sequence type is mutable in the same ways built-in lists
@ -383,7 +383,7 @@ class _ListWrapper(List, collections.MutableSequence,
refuses to save.
On assignment to an attribute of a Model or Trackable object, Python
lists are replaced with _ListWrapper. Wrapping a list in a
lists are replaced with ListWrapper. Wrapping a list in a
`tf.contrib.checkpoint.NoDependency` object prevents this.
"""
@ -393,26 +393,26 @@ class _ListWrapper(List, collections.MutableSequence,
Args:
wrapped_list: The initial value of the data structure. A shallow copy may
be maintained for error checking. `wrapped_list` itself should not be
modified directly after constructing the `_ListWrapper`, and if changes
are detected the `_ListWrapper` will throw an exception on save.
modified directly after constructing the `ListWrapper`, and if changes
are detected the `ListWrapper` will throw an exception on save.
"""
# Monotonic flags which indicate this object would not be restored properly,
# and therefore should throw an error on save to avoid giving the impression
# that restoring it will work.
self._non_append_mutation = False
self._external_modification = False
super(_ListWrapper, self).__init__(wrapped_list)
super(ListWrapper, self).__init__(wrapped_list)
self._last_wrapped_list_snapshot = list(self._storage)
# pylint: disable=protected-access
def __copy__(self):
copied = super(_ListWrapper, self).__copy__()
copied = super(ListWrapper, self).__copy__()
copied._non_append_mutation = self._non_append_mutation
copied._external_modification = self._external_modification
return copied
def __deepcopy__(self, memo):
copied = super(_ListWrapper, self).__deepcopy__(memo)
copied = super(ListWrapper, self).__deepcopy__(memo)
copied._non_append_mutation = self._non_append_mutation
copied._external_modification = self._external_modification
return copied
@ -463,7 +463,7 @@ class _ListWrapper(List, collections.MutableSequence,
"wrap it in a tf.contrib.checkpoint.NoDependency object; it will be "
"automatically un-wrapped and subsequently ignored." % (
self, self._storage, self._last_wrapped_list_snapshot)))
return super(_ListWrapper, self)._checkpoint_dependencies
return super(ListWrapper, self)._checkpoint_dependencies
def __delitem__(self, key):
self._non_append_mutation = True
@ -502,13 +502,13 @@ class _ListWrapper(List, collections.MutableSequence,
def append(self, value):
"""Add a new trackable value."""
self._check_external_modification()
super(_ListWrapper, self).append(value)
super(ListWrapper, self).append(value)
self._update_snapshot()
def extend(self, values):
"""Add a sequence of trackable values."""
self._check_external_modification()
super(_ListWrapper, self).extend(values)
super(ListWrapper, self).extend(values)
self._update_snapshot()
def __imul__(self, y):
@ -518,7 +518,7 @@ class _ListWrapper(List, collections.MutableSequence,
return self
# Relies on super() calling append, which updates the snapshot.
return super(_ListWrapper, self).__imul__(y)
return super(ListWrapper, self).__imul__(y)
def __eq__(self, other):
return self._storage == getattr(other, "_storage", other)
@ -561,7 +561,7 @@ class _ListWrapper(List, collections.MutableSequence,
def _track_value(self, value, name):
"""Allows storage of non-trackable objects."""
try:
value = super(_ListWrapper, self)._track_value(value=value, name=name)
value = super(ListWrapper, self)._track_value(value=value, name=name)
except ValueError:
# Even if this value isn't trackable, we need to make sure
# NoDependency objects get unwrapped.
@ -572,7 +572,7 @@ class _ListWrapper(List, collections.MutableSequence,
def __repr__(self):
return "ListWrapper(%s)" % (repr(self._storage),)
def _list_functions_for_serialization(self):
def _list_functions_for_serialization(self, unused_functions):
return {
str(key): value for key, value in enumerate(self)
if _is_function(value)
@ -653,9 +653,9 @@ class Mapping(TrackableDataStructure, collections.Mapping):
class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
"""Wraps built-in dicts to support restore-on-create for variables.
_DictWrapper is to Mapping as _ListWrapper is to List. Unlike Mapping,
_DictWrapper is to Mapping as ListWrapper is to List. Unlike Mapping,
_DictWrapper allows non-string keys and values and arbitrary mutations (delete
keys, reassign values). Like _ListWrapper, these mutations mean that
keys, reassign values). Like ListWrapper, these mutations mean that
_DictWrapper will raise an exception on save.
"""
@ -827,7 +827,7 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
for key, value in six.iteritems(dict(*args, **kwargs)):
self[key] = value
def _list_functions_for_serialization(self):
def _list_functions_for_serialization(self, unused_serialization_cache):
return {
key: value for key, value in self.items()
if _is_function(value)
@ -860,9 +860,9 @@ def _set_list_item(list_object, index_string, value):
revived_types.register_revived_type(
"trackable_list_wrapper",
lambda obj: isinstance(obj, _ListWrapper),
lambda obj: isinstance(obj, ListWrapper),
versions=[revived_types.VersionedTypeRegistration(
object_factory=lambda proto: _ListWrapper([]),
object_factory=lambda proto: ListWrapper([]),
version=1,
min_producer_version=1,
min_consumer_version=1,


@ -357,14 +357,14 @@ class ListWrapperTest(test.TestCase):
# Skip methods that aren't overridden from object.
continue
if list_method == getattr(data_structures._ListWrapper, name):
if list_method == getattr(data_structures.ListWrapper, name):
not_overridden.append(name)
if not_overridden:
self.fail("_ListWrapper does not override %s" % (not_overridden))
self.fail("ListWrapper does not override %s" % (not_overridden))
def testPickle(self):
original = data_structures._ListWrapper([1, 2])
original = data_structures.ListWrapper([1, 2])
serialized = pickle.dumps(original)
del original
deserialized = pickle.loads(serialized)
@ -372,7 +372,7 @@ class ListWrapperTest(test.TestCase):
def testSameStructure(self):
l = [1]
nest.assert_same_structure(l, data_structures._ListWrapper(copy.copy(l)))
nest.assert_same_structure(l, data_structures.ListWrapper(copy.copy(l)))
def testFunctionCaching(self):
@def_function.function
@ -381,87 +381,87 @@ class ListWrapperTest(test.TestCase):
first_trace = f.get_concrete_function([constant_op.constant(2.)])
second_trace = f.get_concrete_function(
data_structures._ListWrapper([constant_op.constant(3.)]))
data_structures.ListWrapper([constant_op.constant(3.)]))
self.assertIs(first_trace, second_trace)
def testListWrapperBasic(self):
# _ListWrapper, unlike List, compares like the built-in list type (since it
# ListWrapper, unlike List, compares like the built-in list type (since it
# is used to automatically replace lists).
a = tracking.AutoTrackable()
b = tracking.AutoTrackable()
self.assertEqual([a, a],
[a, a])
self.assertEqual(data_structures._ListWrapper([a, a]),
data_structures._ListWrapper([a, a]))
self.assertEqual(data_structures.ListWrapper([a, a]),
data_structures.ListWrapper([a, a]))
self.assertEqual([a, a],
data_structures._ListWrapper([a, a]))
self.assertEqual(data_structures._ListWrapper([a, a]),
data_structures.ListWrapper([a, a]))
self.assertEqual(data_structures.ListWrapper([a, a]),
[a, a])
self.assertNotEqual([a, a],
[b, a])
self.assertNotEqual(data_structures._ListWrapper([a, a]),
data_structures._ListWrapper([b, a]))
self.assertNotEqual(data_structures.ListWrapper([a, a]),
data_structures.ListWrapper([b, a]))
self.assertNotEqual([a, a],
data_structures._ListWrapper([b, a]))
data_structures.ListWrapper([b, a]))
self.assertLess([a], [a, b])
self.assertLess(data_structures._ListWrapper([a]),
data_structures._ListWrapper([a, b]))
self.assertLess(data_structures.ListWrapper([a]),
data_structures.ListWrapper([a, b]))
self.assertLessEqual([a], [a, b])
self.assertLessEqual(data_structures._ListWrapper([a]),
data_structures._ListWrapper([a, b]))
self.assertLessEqual(data_structures.ListWrapper([a]),
data_structures.ListWrapper([a, b]))
self.assertGreater([a, b], [a])
self.assertGreater(data_structures._ListWrapper([a, b]),
data_structures._ListWrapper([a]))
self.assertGreater(data_structures.ListWrapper([a, b]),
data_structures.ListWrapper([a]))
self.assertGreaterEqual([a, b], [a])
self.assertGreaterEqual(data_structures._ListWrapper([a, b]),
data_structures._ListWrapper([a]))
self.assertEqual([a], data_structures._ListWrapper([a]))
self.assertGreaterEqual(data_structures.ListWrapper([a, b]),
data_structures.ListWrapper([a]))
self.assertEqual([a], data_structures.ListWrapper([a]))
self.assertEqual([a], list(data_structures.List([a])))
self.assertEqual([a, a], data_structures._ListWrapper([a]) + [a])
self.assertEqual([a, a], [a] + data_structures._ListWrapper([a]))
self.assertIsInstance(data_structures._ListWrapper([a]), list)
self.assertEqual([a, a], data_structures.ListWrapper([a]) + [a])
self.assertEqual([a, a], [a] + data_structures.ListWrapper([a]))
self.assertIsInstance(data_structures.ListWrapper([a]), list)
def testAcceptsNonTrackableContent(self):
l = data_structures._ListWrapper([1, 2, 3])
l = data_structures.ListWrapper([1, 2, 3])
self.assertEqual(l, [1, 2, 3])
def testWrapperChangesList(self):
l = []
l_wrapper = data_structures._ListWrapper(l)
l_wrapper = data_structures.ListWrapper(l)
l_wrapper.append(1)
self.assertEqual([1], l)
def testListChangesWrapper(self):
l = []
l_wrapper = data_structures._ListWrapper(l)
l_wrapper = data_structures.ListWrapper(l)
l.append(1)
self.assertEqual([1], l_wrapper)
def testLayerCollectionWithExternalMutation(self):
l = []
l_wrapper = data_structures._ListWrapper(l)
l_wrapper = data_structures.ListWrapper(l)
layer = core.Dense(1)
l.append(layer)
self.assertEqual([layer], l_wrapper.layers)
def testNotHashable(self):
with self.assertRaises(TypeError):
hash(data_structures._ListWrapper())
hash(data_structures.ListWrapper())
def testDelItem(self):
l = data_structures._ListWrapper([1, 2, 3, 4])
l = data_structures.ListWrapper([1, 2, 3, 4])
del l[0]
self.assertEqual(l, [2, 3, 4])
self.assertUnableToSave(l, "Unable to save .*__delitem__")
def testDelSlice(self):
l = data_structures._ListWrapper([1, 2, 3, 4])
l = data_structures.ListWrapper([1, 2, 3, 4])
del l[2:3]
self.assertEqual(l, [1, 2, 4])
self.assertUnableToSave(l, "Unable to save .*__delslice__")
def testSetSlice_canSaveForNonTrackableItems(self):
l = data_structures._ListWrapper([1, 2, 3, 4])
l = data_structures.ListWrapper([1, 2, 3, 4])
l[:] = 2, 8, 9, 0
self.assertEqual(l, [2, 8, 9, 0])
l._maybe_initialize_trackable() # pylint: disable=protected-access
@ -470,30 +470,30 @@ class ListWrapperTest(test.TestCase):
def testSetSlice_cannotSaveIfTrackableModified(self):
v1 = resource_variable_ops.ResourceVariable(1.)
v2 = resource_variable_ops.ResourceVariable(1.)
l = data_structures._ListWrapper([1, 2, v1, v2])
l = data_structures.ListWrapper([1, 2, v1, v2])
l[:] = 2, 8, 9, v2
self.assertEqual(l, [2, 8, 9, v2])
self.assertUnableToSave(l, "Unable to save .*__setslice__")
def testSetSlice_truncate(self):
l = data_structures._ListWrapper([1, 2, 3, 4])
l = data_structures.ListWrapper([1, 2, 3, 4])
l[:] = []
self.assertEqual(l, [])
def testSetSlice_extend(self):
l = data_structures._ListWrapper([1, 2, 3, 4])
l = data_structures.ListWrapper([1, 2, 3, 4])
l[2:] = 1, 2, 3, 4
self.assertEqual(l, [1, 2, 1, 2, 3, 4])
def testIMulNegative(self):
l = data_structures._ListWrapper([1, 2, 3, 4])
l = data_structures.ListWrapper([1, 2, 3, 4])
l *= -1
self.assertEqual(l, [1, 2, 3, 4] * -1)
self.assertUnableToSave(l, "Unable to save")
def testIMulPositive(self):
v = variables.Variable(1.)
l = data_structures._ListWrapper([1, 2, 3, 4, v])
l = data_structures.ListWrapper([1, 2, 3, 4, v])
self.assertEqual([("4", v)], l._checkpoint_dependencies)
root = util.Checkpoint(l=l)
prefix = os.path.join(self.get_temp_dir(), "ckpt")
@ -506,7 +506,7 @@ class ListWrapperTest(test.TestCase):
self.assertAllClose(1., v.numpy())
def testSort(self):
l = data_structures._ListWrapper([1, 2, 3, 4])
l = data_structures.ListWrapper([1, 2, 3, 4])
l.sort()
self.assertEqual(l, [1, 2, 3, 4])
# Regardless of being a no-op for the input list, we still refuse to save.
@ -708,7 +708,7 @@ class MappingTests(test.TestCase):
# methods/properties on the object. So the options are either not to
# subclass dict (in which case update will call normal iter methods, but the
# object won't pass isinstance checks) or to subclass dict and keep that
# storage updated (no shadowing all its methods like _ListWrapper).
# storage updated (no shadowing all its methods like ListWrapper).
new_dict.update(model.d)
self.assertEqual({1: 3}, new_dict)
@ -831,14 +831,14 @@ class MappingTests(test.TestCase):
def testListAddOrder(self):
self.assertEqual([1., 2.],
data_structures._ListWrapper([1.])
+ data_structures._ListWrapper([2.]))
data_structures.ListWrapper([1.])
+ data_structures.ListWrapper([2.]))
self.assertEqual([1., 2.],
data_structures._ListWrapper([1.])
data_structures.ListWrapper([1.])
+ [2.])
self.assertEqual([1., 2.],
[1.]
+ data_structures._ListWrapper([2.]))
+ data_structures.ListWrapper([2.]))
def testSameStructure(self):
d = {1: "a"}


@ -41,6 +41,7 @@ def has_weights(obj):
def filter_empty_layer_containers(layer_list):
"""Filter out empty Layer-like containers and uniquify."""
# TODO(b/130381733): Make this an attribute in base_layer.Layer.
existing = object_identity.ObjectIdentitySet()
to_visit = layer_list[::-1]
filtered = []


@ -94,7 +94,7 @@ class AutoTrackable(base.Trackable):
"""Override to allow TrackableBase to disable dependency tracking."""
return data_structures.NoDependency(value)
def _list_functions_for_serialization(self):
def _list_functions_for_serialization(self, unused_serialization_cache):
"""Return a dict of `Function`s of a trackable."""
functions = {}
for attribute_name in dir(self):
@ -192,7 +192,7 @@ class CapturableResource(base.Trackable):
self._resource_handle = self._create_resource()
return self._resource_handle
def _list_functions_for_serialization(self):
def _list_functions_for_serialization(self, unused_functions):
@def_function.function(input_signature=[], autograph=False)
def _creator():
resource = self._create_resource()


@ -6,4 +6,12 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}


@ -6,4 +6,12 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}


@ -6,4 +6,12 @@ tf_class {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}