Keras models and layers saving and reviving code. Implements go/tf-model-serialization.
To save and revive a model: 1. Save the model using tf.saved_model.save 2. call load_from_save_model_v2 This restores various metadata about Keras models and layers, as well as their call and loss functions. Changes to object serialization: - Adds private fields for tracking object's identifier and metadata. - Added _list_extra_dependencies_for_serialization, which allows objects to save extra dependencies when serialized to SavedModel. - Object graph view maintains a serialization cache object that is passed to each object when serializing functions/extra dependencies. PiperOrigin-RevId: 251386039
This commit is contained in:
parent
ece5314ddb
commit
eff4ae822a
tensorflow
@ -74,6 +74,8 @@ message SavedUserObject {
|
||||
string identifier = 1;
|
||||
// Version information from the producer of this SavedUserObject.
|
||||
VersionDef version = 2;
|
||||
// Initialization-related metadata.
|
||||
string metadata = 3;
|
||||
}
|
||||
|
||||
// A SavedAsset points to an asset in the MetaGraph.
|
||||
|
@ -160,7 +160,6 @@ py_library(
|
||||
"engine/__init__.py",
|
||||
"engine/base_layer.py",
|
||||
"engine/input_layer.py",
|
||||
"engine/input_spec.py",
|
||||
"engine/network.py",
|
||||
"engine/partial_batch_padding_handler.py",
|
||||
"engine/saving.py",
|
||||
@ -185,6 +184,7 @@ py_library(
|
||||
":constraints",
|
||||
":engine_utils",
|
||||
":initializers",
|
||||
":input_spec",
|
||||
":losses",
|
||||
":mode_keys",
|
||||
":optimizers",
|
||||
@ -210,6 +210,20 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "input_spec",
|
||||
srcs = ["engine/input_spec.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":backend",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "saving",
|
||||
srcs = [
|
||||
@ -224,12 +238,18 @@ py_library(
|
||||
deps = [
|
||||
":backend",
|
||||
":engine_utils",
|
||||
":input_spec",
|
||||
":mode_keys",
|
||||
":optimizers",
|
||||
":regularizers",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:saver",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/saved_model",
|
||||
"//tensorflow/python/saved_model/model_utils",
|
||||
"//tensorflow/python/training/tracking",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1239,6 +1259,17 @@ tf_py_test(
|
||||
tags = ["notsan"],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "input_spec_test",
|
||||
size = "small",
|
||||
srcs = ["engine/input_spec_test.py"],
|
||||
additional_deps = [
|
||||
":keras",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "training_test",
|
||||
size = "medium",
|
||||
@ -1511,6 +1542,7 @@ tf_py_test(
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:client_testlib",
|
||||
],
|
||||
shard_count = 4,
|
||||
tags = [
|
||||
"no_oss", # TODO(b/119349471): Re-enable
|
||||
"no_windows",
|
||||
|
@ -22,6 +22,7 @@ import collections
|
||||
import functools
|
||||
import inspect # Necessary supplement to tf_inspect to deal with variadic args.
|
||||
import itertools
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
from six.moves import zip # pylint: disable=redefined-builtin
|
||||
@ -46,9 +47,11 @@ from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import input_spec
|
||||
from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
|
||||
from tensorflow.python.keras.mixed_precision.experimental import policy
|
||||
from tensorflow.python.keras.saving import saved_model
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
# A module that only depends on `keras.layers` import these from here.
|
||||
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
|
||||
from tensorflow.python.keras.utils.generic_utils import to_snake_case # pylint: disable=unused-import
|
||||
from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list # pylint: disable=unused-import
|
||||
from tensorflow.python.module import module
|
||||
@ -64,6 +67,7 @@ from tensorflow.python.training.tracking import tracking
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import serialization
|
||||
from tensorflow.python.util import tf_decorator
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
@ -2131,6 +2135,109 @@ class Layer(module.Module):
|
||||
return ('mask' in self._call_fn_args or
|
||||
getattr(self, 'compute_mask', None) is not None)
|
||||
|
||||
@property
|
||||
def _object_identifier(self):
|
||||
"""String stored in object identifier field in the SavedModel proto.
|
||||
|
||||
Returns:
|
||||
A string with the object identifier, which is used at load time.
|
||||
"""
|
||||
return '_tf_keras_layer'
|
||||
|
||||
@property
|
||||
def _tracking_metadata(self):
|
||||
"""String stored in metadata field in the SavedModel proto.
|
||||
|
||||
Returns:
|
||||
A serialized JSON storing information necessary for recreating this layer.
|
||||
"""
|
||||
# TODO(kathywu): Add support for metrics serialization.
|
||||
# TODO(kathywu): Synchronize with the keras spec (go/keras-json-spec) once
|
||||
# the python config serialization has caught up.
|
||||
|
||||
# Create a dictionary containing python layer state attributes. Any new
|
||||
# attribute that impacts the layer execution in some way should be added to
|
||||
# this dict.
|
||||
# Unlike a model's JSON configuration, which only
|
||||
# contains class_name and each layer's get_config() object, this stores more
|
||||
# information to accurately recreate the layer.
|
||||
# For backwards compatibility, any changes to this list should be additive.
|
||||
# Modifying or removing attributes may only be done with a sufficient
|
||||
# explanation.
|
||||
|
||||
metadata = dict(
|
||||
class_name=type(self).__name__,
|
||||
name=self.name,
|
||||
trainable=self.trainable,
|
||||
expects_training_arg=self._expects_training_arg,
|
||||
dtype=self.dtype,
|
||||
batch_input_shape=getattr(self, '_batch_input_shape', None))
|
||||
|
||||
try:
|
||||
# Store the config dictionary, which is only used by the revived object
|
||||
# to return the original config when revived_obj.get_config() is called.
|
||||
# It is not important for recreating the revived object.
|
||||
metadata['config'] = self.get_config()
|
||||
except NotImplementedError:
|
||||
# in the case of a subclassed model, the get_config() method will throw
|
||||
# a NotImplementedError.
|
||||
pass
|
||||
if self.input_spec is not None:
|
||||
metadata['input_spec'] = nest.map_structure(
|
||||
lambda x: x.get_config(), self.input_spec)
|
||||
else:
|
||||
metadata['input_spec'] = None
|
||||
if (self.activity_regularizer is not None and
|
||||
hasattr(self.activity_regularizer, 'get_config')):
|
||||
metadata['activity_regularizer'] = serialize_keras_object(
|
||||
self.activity_regularizer)
|
||||
else:
|
||||
metadata['activity_regularizer'] = None
|
||||
return json.dumps(metadata, default=serialization.get_json_type)
|
||||
|
||||
def _list_extra_dependencies_for_serialization(self, serialization_cache):
|
||||
"""Lists extra dependencies to serialize to SavedModel.
|
||||
|
||||
By overriding this method, extra dependencies can be attached to the
|
||||
serialized Layer. For example, this is used to save the list of `variables`
|
||||
and `trainable_variables`, which are python properties in a Layer object,
|
||||
but are represented as a static list in the SavedModel.
|
||||
|
||||
Args:
|
||||
serialization_cache: A dictionary shared between all objects in the same
|
||||
object graph. This object is passed to both
|
||||
`_list_extra_dependencies_for_serialization` and
|
||||
`_list_functions_for_serialization`.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping attribute names to trackable objects. The entire list
|
||||
of attributes are listed in the `saved_model._LayerAttributes` class.
|
||||
"""
|
||||
return (saved_model.serialize_all_attributes(self, serialization_cache)
|
||||
.objects_to_serialize)
|
||||
|
||||
def _list_functions_for_serialization(self, serialization_cache):
|
||||
"""Lists the functions to include when serializing a Layer.
|
||||
|
||||
Args:
|
||||
serialization_cache: Dictionary passed to all objects in the same object
|
||||
graph during serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping attribute names to `Function` or
|
||||
`ConcreteFunction`. The entire list of attributes are listed in the
|
||||
`saved_model._LayerAttributes` class.
|
||||
"""
|
||||
# Create a dictionary containing the layer's call and loss functions.
|
||||
fns = (saved_model.serialize_all_attributes(self, serialization_cache)
|
||||
.functions_to_serialize)
|
||||
# The parent Autotrackable class saves all user-defined tf.functions, and
|
||||
# returns them in _list_functions_for_serialization(). Add these functions
|
||||
# to the dict.
|
||||
fns.update(super(Layer, self)._list_functions_for_serialization(
|
||||
serialization_cache))
|
||||
return fns
|
||||
|
||||
|
||||
class Node(object):
|
||||
"""A `Node` describes the connectivity between two layers.
|
||||
|
@ -289,6 +289,9 @@ def is_in_eager_or_tf_function():
|
||||
|
||||
def is_in_tf_function():
|
||||
"""Returns if inside of a tf.function."""
|
||||
# Check if running in V1 graph mode.
|
||||
if not ops.executing_eagerly_outside_functions():
|
||||
return False
|
||||
if not ops.inside_function():
|
||||
return False
|
||||
# Check if inside Keras FuncGraph.
|
||||
|
@ -20,12 +20,16 @@ from __future__ import print_function
|
||||
|
||||
from six.moves import zip # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@keras_export('keras.layers.InputSpec', v1=['keras.layers.InputSpec'])
|
||||
@keras_export('keras.layers.InputSpec')
|
||||
@tf_export(v1=['layers.InputSpec'])
|
||||
class InputSpec(object):
|
||||
"""Specifies the ndim, dtype and shape of every input to a layer.
|
||||
@ -54,15 +58,27 @@ class InputSpec(object):
|
||||
max_ndim=None,
|
||||
min_ndim=None,
|
||||
axes=None):
|
||||
self.dtype = dtype
|
||||
self.shape = shape
|
||||
self.dtype = dtypes.as_dtype(dtype).name if dtype is not None else None
|
||||
if shape is not None:
|
||||
self.ndim = len(shape)
|
||||
self.shape = shape
|
||||
else:
|
||||
self.ndim = ndim
|
||||
self.shape = None
|
||||
self.max_ndim = max_ndim
|
||||
self.min_ndim = min_ndim
|
||||
self.axes = axes or {}
|
||||
try:
|
||||
axes = axes or {}
|
||||
self.axes = {int(k): axes[k] for k in axes}
|
||||
except (ValueError, TypeError):
|
||||
raise TypeError('The keys in axes must be integers.')
|
||||
|
||||
if self.axes and (self.ndim is not None or self.max_ndim is not None):
|
||||
max_dim = (self.ndim if self.ndim else self.max_ndim) - 1
|
||||
max_axis = max(self.axes)
|
||||
if max_axis > max_dim:
|
||||
raise ValueError('Axis {} is greater than the maximum allowed value: {}'
|
||||
.format(max_axis, max_dim))
|
||||
|
||||
def __repr__(self):
|
||||
spec = [('dtype=' + str(self.dtype)) if self.dtype else '',
|
||||
@ -73,6 +89,42 @@ class InputSpec(object):
|
||||
('axes=' + str(self.axes)) if self.axes else '']
|
||||
return 'InputSpec(%s)' % ', '.join(x for x in spec if x)
|
||||
|
||||
def get_config(self):
|
||||
return {
|
||||
'dtype': self.dtype,
|
||||
'shape': self.shape,
|
||||
'ndim': self.ndim,
|
||||
'max_ndim': self.max_ndim,
|
||||
'min_ndim': self.min_ndim,
|
||||
'axes': self.axes}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config):
|
||||
return cls(**config)
|
||||
|
||||
|
||||
def to_tensor_shape(spec):
|
||||
"""Returns a tf.TensorShape object that matches the shape specifications.
|
||||
|
||||
If the InputSpec's shape or ndim is defined, this method will return a fully
|
||||
or partially-known shape. Otherwise, the returned TensorShape is None.
|
||||
|
||||
Args:
|
||||
spec: an InputSpec object.
|
||||
|
||||
Returns:
|
||||
a tf.TensorShape object
|
||||
"""
|
||||
if spec.ndim is None and spec.shape is None:
|
||||
return tensor_shape.TensorShape(None)
|
||||
elif spec.shape is not None:
|
||||
return tensor_shape.TensorShape(spec.shape)
|
||||
else:
|
||||
shape = [None] * spec.ndim
|
||||
for a in spec.axes:
|
||||
shape[a] = spec.axes[a] # Assume that axes is defined
|
||||
return tensor_shape.TensorShape(shape)
|
||||
|
||||
|
||||
def assert_input_compatibility(input_spec, inputs, layer_name):
|
||||
"""Checks compatibility between the layer and provided inputs.
|
||||
@ -168,3 +220,12 @@ def assert_input_compatibility(input_spec, inputs, layer_name):
|
||||
' is incompatible with layer ' + layer_name +
|
||||
': expected shape=' + str(spec.shape) +
|
||||
', found shape=' + str(shape))
|
||||
|
||||
|
||||
def to_tensor_spec(input_spec, default_dtype=None):
|
||||
"""Converts a Keras InputSpec object to a TensorSpec."""
|
||||
default_dtype = default_dtype or backend.floatx()
|
||||
if isinstance(input_spec, InputSpec):
|
||||
dtype = input_spec.dtype or default_dtype
|
||||
return tensor_spec.TensorSpec(to_tensor_shape(input_spec), dtype)
|
||||
return tensor_spec.TensorSpec(None, default_dtype)
|
||||
|
66
tensorflow/python/keras/engine/input_spec_test.py
Normal file
66
tensorflow/python/keras/engine/input_spec_test.py
Normal file
@ -0,0 +1,66 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""InputSpec tests."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.keras.engine import input_spec
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class InputSpecTest(test.TestCase):
|
||||
|
||||
def test_axes_initialization(self):
|
||||
input_spec.InputSpec(shape=[1, None, 2, 3], axes={3: 5, '2': 2})
|
||||
with self.assertRaisesRegexp(ValueError, 'Axis 4 is greater than'):
|
||||
input_spec.InputSpec(shape=[1, None, 2, 3], axes={4: 5})
|
||||
with self.assertRaisesRegexp(TypeError, 'keys in axes must be integers'):
|
||||
input_spec.InputSpec(shape=[1, None, 2, 3], axes={'string': 5})
|
||||
|
||||
|
||||
class InputSpecToTensorShapeTest(test.TestCase):
|
||||
|
||||
def test_defined_shape(self):
|
||||
spec = input_spec.InputSpec(shape=[1, None, 2, 3])
|
||||
self.assertAllEqual(
|
||||
[1, None, 2, 3], input_spec.to_tensor_shape(spec).as_list())
|
||||
|
||||
def test_defined_ndims(self):
|
||||
spec = input_spec.InputSpec(ndim=5)
|
||||
self.assertAllEqual(
|
||||
[None] * 5, input_spec.to_tensor_shape(spec).as_list())
|
||||
|
||||
spec = input_spec.InputSpec(ndim=0)
|
||||
self.assertAllEqual(
|
||||
[], input_spec.to_tensor_shape(spec).as_list())
|
||||
|
||||
spec = input_spec.InputSpec(ndim=3, axes={1: 3, -1: 2})
|
||||
self.assertAllEqual(
|
||||
[None, 3, 2], input_spec.to_tensor_shape(spec).as_list())
|
||||
|
||||
def test_undefined_shapes(self):
|
||||
spec = input_spec.InputSpec(max_ndim=5)
|
||||
with self.assertRaisesRegexp(ValueError, 'unknown TensorShape'):
|
||||
input_spec.to_tensor_shape(spec).as_list()
|
||||
|
||||
spec = input_spec.InputSpec(min_ndim=5, max_ndim=5)
|
||||
with self.assertRaisesRegexp(ValueError, 'unknown TensorShape'):
|
||||
input_spec.to_tensor_shape(spec).as_list()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
@ -1663,6 +1663,10 @@ class Network(base_layer.Layer):
|
||||
'inputs or `build()` is called with an `input_shape`.' %
|
||||
self.name)
|
||||
|
||||
@property
|
||||
def _object_identifier(self):
|
||||
return '_tf_keras_network'
|
||||
|
||||
|
||||
def _is_hdf5_filepath(filepath):
|
||||
return (filepath.endswith('.h5') or filepath.endswith('.keras') or
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import tf2
|
||||
@ -55,6 +56,7 @@ from tensorflow.python.ops.losses import util as tf_losses_utils
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import serialization
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
@ -1707,26 +1709,6 @@ class Model(network.Network):
|
||||
batch_size = 32
|
||||
return batch_size
|
||||
|
||||
def _list_functions_for_serialization(self):
|
||||
"""If available, saves a trace of call using self.inputs."""
|
||||
all_functions = super(Model, self)._list_functions_for_serialization()
|
||||
try:
|
||||
# pylint:disable=pointless-statement
|
||||
self.inputs
|
||||
self.input_names
|
||||
# pylint:enable=pointless-statement
|
||||
except AttributeError:
|
||||
# If the model does not have inputs set, because it was not called or its
|
||||
# input shapes were not recorded, we won't have a signature so can't trace
|
||||
# a function. But the user may still save an object with this Model
|
||||
# attached; we won't fail the whole tf.saved_model.save.
|
||||
pass
|
||||
else:
|
||||
if '_default_save_signature' not in all_functions:
|
||||
all_functions['_default_save_signature'] = (
|
||||
saving_utils.trace_model_call(self))
|
||||
return all_functions
|
||||
|
||||
def _prepare_sample_weights(self, sample_weights=None):
|
||||
"""Sets sample weight attribute on the model."""
|
||||
# List with the same length as model outputs.
|
||||
@ -2717,6 +2699,17 @@ class Model(network.Network):
|
||||
initial_epoch, mode)
|
||||
return initial_epoch
|
||||
|
||||
@property
|
||||
def _object_identifier(self):
|
||||
return '_tf_keras_model'
|
||||
|
||||
@property
|
||||
def _tracking_metadata(self):
|
||||
metadata = json.loads(super(Model, self)._tracking_metadata)
|
||||
metadata.update(saving_utils.model_metadata(
|
||||
self, include_optimizer=True, require_config=False))
|
||||
return json.dumps(metadata, default=serialization.get_json_type)
|
||||
|
||||
def _assert_compile_was_called(self):
|
||||
# Checks whether `compile` has been called. If it has been called,
|
||||
# then the optimizer is set. This is different from whether the
|
||||
|
@ -1010,7 +1010,7 @@ def _get_slot_key_from_var(var, slot_name):
|
||||
return name + "/" + slot_name
|
||||
|
||||
|
||||
class _RestoredOptimizer(OptimizerV2):
|
||||
class RestoredOptimizer(OptimizerV2):
|
||||
"""A non-functional Optimizer implementation for checkpoint compatibility.
|
||||
|
||||
Holds slot variables and hyperparameters when an optimizer is restored from a
|
||||
@ -1022,7 +1022,7 @@ class _RestoredOptimizer(OptimizerV2):
|
||||
# methods.
|
||||
|
||||
def __init__(self):
|
||||
super(_RestoredOptimizer, self).__init__("_RestoredOptimizer")
|
||||
super(RestoredOptimizer, self).__init__("RestoredOptimizer")
|
||||
self._hypers_created = True
|
||||
|
||||
def get_config(self):
|
||||
@ -1036,9 +1036,9 @@ revived_types.register_revived_type(
|
||||
"optimizer",
|
||||
lambda obj: isinstance(obj, OptimizerV2),
|
||||
versions=[revived_types.VersionedTypeRegistration(
|
||||
object_factory=lambda proto: _RestoredOptimizer(),
|
||||
object_factory=lambda proto: RestoredOptimizer(),
|
||||
version=1,
|
||||
min_producer_version=1,
|
||||
min_consumer_version=1,
|
||||
setter=_RestoredOptimizer._set_hyper # pylint: disable=protected-access
|
||||
setter=RestoredOptimizer._set_hyper # pylint: disable=protected-access
|
||||
)])
|
||||
|
@ -28,6 +28,7 @@ from six.moves import zip # pylint: disable=redefined-builtin
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.keras.saving import model_config as model_config_lib
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.utils import conv_utils
|
||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
@ -71,8 +72,6 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
|
||||
if h5py is None:
|
||||
raise ImportError('`save_model` requires h5py.')
|
||||
|
||||
from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
|
||||
|
||||
# TODO(psv) Add warning when we save models that contain non-serializable
|
||||
# entities like metrics added using `add_metric` and losses added using
|
||||
# `add_loss.`
|
||||
@ -91,48 +90,21 @@ def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
|
||||
opened_new_file = False
|
||||
|
||||
try:
|
||||
f.attrs['keras_version'] = str(keras_version).encode('utf8')
|
||||
f.attrs['backend'] = K.backend().encode('utf8')
|
||||
f.attrs['model_config'] = json.dumps(
|
||||
{
|
||||
'class_name': model.__class__.__name__,
|
||||
'config': model.get_config()
|
||||
},
|
||||
default=serialization.get_json_type).encode('utf8')
|
||||
model_metadata = saving_utils.model_metadata(model, include_optimizer)
|
||||
for k, v in model_metadata.items():
|
||||
f.attrs[k] = json.dumps(
|
||||
v, default=serialization.get_json_type).encode('utf8')
|
||||
|
||||
model_weights_group = f.create_group('model_weights')
|
||||
model_layers = model.layers
|
||||
save_weights_to_hdf5_group(model_weights_group, model_layers)
|
||||
|
||||
if include_optimizer and model.optimizer:
|
||||
if isinstance(model.optimizer, optimizers.TFOptimizer):
|
||||
logging.warning(
|
||||
'TensorFlow optimizers do not '
|
||||
'make it possible to access '
|
||||
'optimizer attributes or optimizer state '
|
||||
'after instantiation. '
|
||||
'As a result, we cannot save the optimizer '
|
||||
'as part of the model save file. '
|
||||
'You will have to compile your model again after loading it. '
|
||||
'Prefer using a Keras optimizer instead '
|
||||
'(see keras.io/optimizers).')
|
||||
else:
|
||||
f.attrs['training_config'] = json.dumps(
|
||||
{
|
||||
'optimizer_config': {
|
||||
'class_name': model.optimizer.__class__.__name__,
|
||||
'config': model.optimizer.get_config()
|
||||
},
|
||||
'loss': model.loss,
|
||||
'metrics': model._compile_metrics,
|
||||
'weighted_metrics': model._compile_weighted_metrics,
|
||||
'sample_weight_mode': model.sample_weight_mode,
|
||||
'loss_weights': model.loss_weights,
|
||||
},
|
||||
default=serialization.get_json_type).encode('utf8')
|
||||
# TODO(b/128683857): Add integration tests between tf.keras and external
|
||||
# Keras, to avoid breaking TF.js users.
|
||||
if (include_optimizer and model.optimizer and
|
||||
not isinstance(model.optimizer, optimizers.TFOptimizer)):
|
||||
save_optimizer_weights_to_hdf5_group(f, model.optimizer)
|
||||
|
||||
# Save optimizer weights.
|
||||
save_optimizer_weights_to_hdf5_group(f, model.optimizer)
|
||||
f.flush()
|
||||
finally:
|
||||
if opened_new_file:
|
||||
|
@ -60,6 +60,14 @@ def save_model(model,
|
||||
the exact same state, without any of the code
|
||||
used for model definition or training.
|
||||
|
||||
_SavedModel serialization_ (not yet added)
|
||||
|
||||
The SavedModel serialization path uses `tf.saved_model.save` to save the model
|
||||
and all trackable objects attached to the model (e.g. layers and variables).
|
||||
`@tf.function`-decorated methods are also saved. Additional trackable objects
|
||||
and functions are added to the SavedModel to allow the model to be
|
||||
loaded back as a Keras Model object.
|
||||
|
||||
Arguments:
|
||||
model: Keras model instance to be saved.
|
||||
filepath: One of the following:
|
||||
@ -147,7 +155,7 @@ def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=
|
||||
|
||||
if isinstance(filepath, six.string_types):
|
||||
loader_impl.parse_saved_model(filepath)
|
||||
return saved_model.load_from_saved_model(filepath)
|
||||
return saved_model.load_from_saved_model_v2(filepath)
|
||||
|
||||
raise IOError(
|
||||
'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
|
||||
|
@ -12,37 +12,77 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# pylint: disable=protected-access
|
||||
"""Utility functions to save/load keras Model to/from SavedModel."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import six
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function as defun
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.keras import regularizers
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import input_spec
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.keras.saving import model_from_json
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.utils import mode_keys
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import builder as saved_model_builder
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import load
|
||||
from tensorflow.python.saved_model import model_utils
|
||||
from tensorflow.python.saved_model import save as save_lib
|
||||
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.training.tracking import graph_view
|
||||
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
|
||||
from tensorflow.python.training.tracking.tracking import AutoTrackable
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
# To avoid circular dependencies between keras/engine and keras/saving,
|
||||
# code in keras/saving must delay imports.
|
||||
|
||||
# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
|
||||
# once the issue with copybara is fixed.
|
||||
# pylint:disable=g-inconsistent-quotes
|
||||
metrics_lib = LazyLoader("metrics_lib", globals(),
|
||||
"tensorflow.python.keras.metrics")
|
||||
models_lib = LazyLoader("models_lib", globals(),
|
||||
"tensorflow.python.keras.models")
|
||||
base_layer = LazyLoader(
|
||||
"base_layer", globals(),
|
||||
"tensorflow.python.keras.engine.base_layer")
|
||||
network_lib = LazyLoader(
|
||||
"network_lib", globals(),
|
||||
"tensorflow.python.keras.engine.network")
|
||||
sequential = LazyLoader(
|
||||
"sequential", globals(),
|
||||
"tensorflow.python.keras.engine.sequential")
|
||||
training_lib = LazyLoader(
|
||||
"training_lib", globals(),
|
||||
"tensorflow.python.keras.engine.training")
|
||||
# pylint:enable=g-inconsistent-quotes
|
||||
|
||||
|
||||
@keras_export('keras.experimental.export_saved_model')
|
||||
def export_saved_model(model,
|
||||
@ -144,9 +184,7 @@ def _export_model_variables(model, saved_model_path):
|
||||
|
||||
def _save_v1_format(model, path, custom_objects, as_text, input_signature):
|
||||
"""Exports model to v1 SavedModel format."""
|
||||
from tensorflow.python.keras.engine import sequential # pylint: disable=g-import-not-at-top
|
||||
|
||||
if not model._is_graph_network:
|
||||
if not model._is_graph_network: # pylint: disable=protected-access
|
||||
if isinstance(model, sequential.Sequential):
|
||||
# If input shape is not directly set in the model, the exported model
|
||||
# will infer the expected shapes of the input from the model.
|
||||
@ -163,7 +201,7 @@ def _save_v1_format(model, path, custom_objects, as_text, input_signature):
|
||||
'Subclassed models can only be exported for serving. Please set '
|
||||
'argument serving_only=True.')
|
||||
|
||||
builder = saved_model_builder._SavedModelBuilder(path)
|
||||
builder = saved_model_builder._SavedModelBuilder(path) # pylint: disable=protected-access
|
||||
|
||||
# Manually save variables to export them in an object-based checkpoint. This
|
||||
# skips the `builder.add_meta_graph_and_variables()` step, which saves a
|
||||
@ -233,7 +271,6 @@ def _export_mode(
|
||||
ValueError: If the train/eval mode is being exported, but the model does
|
||||
not have an optimizer.
|
||||
"""
|
||||
from tensorflow.python.keras import models as models_lib # pylint: disable=g-import-not-at-top
|
||||
compile_clone = (mode != mode_keys.ModeKeys.PREDICT)
|
||||
if compile_clone and not model.optimizer:
|
||||
raise ValueError(
|
||||
@ -264,12 +301,12 @@ def _export_mode(
|
||||
# Extract update and train ops from train/test/predict functions.
|
||||
train_op = None
|
||||
if mode == mode_keys.ModeKeys.TRAIN:
|
||||
clone._make_train_function()
|
||||
clone._make_train_function() # pylint: disable=protected-access
|
||||
train_op = clone.train_function.updates_op
|
||||
elif mode == mode_keys.ModeKeys.TEST:
|
||||
clone._make_test_function()
|
||||
clone._make_test_function() # pylint: disable=protected-access
|
||||
else:
|
||||
clone._make_predict_function()
|
||||
clone._make_predict_function() # pylint: disable=protected-access
|
||||
g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)
|
||||
|
||||
with session.Session().as_default():
|
||||
@ -292,7 +329,7 @@ def _export_mode(
|
||||
# Add graph and variables to SavedModel.
|
||||
# TODO(b/113134168): Switch to add_meta_graph_and_variables.
|
||||
clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
|
||||
builder._has_saved_variables = True
|
||||
builder._has_saved_variables = True # pylint: disable=protected-access
|
||||
|
||||
# Add graph to the SavedModel builder.
|
||||
builder.add_meta_graph(
|
||||
@ -313,7 +350,7 @@ def _create_signature_def_map(model, mode):
|
||||
inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)}
|
||||
if model.optimizer:
|
||||
targets_dict = {x.name.split(':')[0]: x
|
||||
for x in model._targets if x is not None}
|
||||
for x in model._targets if x is not None} # pylint: disable=protected-access
|
||||
inputs_dict.update(targets_dict)
|
||||
outputs_dict = {name: x
|
||||
for name, x in zip(model.output_names, model.outputs)}
|
||||
@ -325,9 +362,8 @@ def _create_signature_def_map(model, mode):
|
||||
local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
|
||||
vars_to_add = set()
|
||||
if metrics is not None:
|
||||
from tensorflow.python.keras.metrics import Metric # pylint: disable=g-import-not-at-top
|
||||
for key, value in six.iteritems(metrics):
|
||||
if isinstance(value, Metric):
|
||||
if isinstance(value, metrics_lib.Metric):
|
||||
vars_to_add.update(value.variables)
|
||||
# Convert Metric instances to (value_tensor, update_op) tuple.
|
||||
metrics[key] = (value.result(), value.updates[0])
|
||||
@ -406,3 +442,823 @@ def load_from_saved_model(saved_model_path, custom_objects=None):
|
||||
compat.as_text(constants.VARIABLES_FILENAME))
|
||||
model.load_weights(checkpoint_prefix)
|
||||
return model
|
||||
|
||||
################################################################################
|
||||
# Functional Style/V2 SavedModel functions #
|
||||
################################################################################
|
||||
|
||||
# All serialized attributes are listed within SerializedAttributes classes. See
|
||||
# the docstring in SerializedAttributes for more context
|
||||
|
||||
# All attributes are saved under the 'keras_api' namespace. Only common
|
||||
# endpoints are attached directly to the root object.
|
||||
_KERAS_ATTR = 'keras_api'
|
||||
# Keys for the serialization cache.
|
||||
# Maps to the keras serialization dict {Layer --> SerializedAttributes object}
|
||||
_KERAS_CACHE_KEY = 'keras_serialized_attributes'
|
||||
|
||||
|
||||
class SerializedAttributes(object):
|
||||
"""Class that tracks and validates all serialization attributes.
|
||||
|
||||
Keras models contain many Python-defined components. For example, the
|
||||
trainable_variable property lists the model's trainable variables by
|
||||
recursively retrieving the trainable variables from each of the child layers.
|
||||
Another example is model.call, a python function that calls child layers and
|
||||
adds ops to the backend graph.
|
||||
|
||||
Only Tensorflow checkpointable objects and functions can be serialized to
|
||||
SavedModel. Serializing a Keras model as-is results in a checkpointable object
|
||||
that does not resemble a Keras model at all. Thus, extra checkpointable
|
||||
objects and functions must be created during serialization.
|
||||
|
||||
**Defining new serialized attributes**
|
||||
Child classes should be defined using:
|
||||
SerializedAttributes.with_attributes(
|
||||
'name', checkpointable_objects=[...], functions=[...], copy_from=[...])
|
||||
This class is used to cache generated checkpointable objects and functions,
|
||||
ensuring that new objects and functions are generated a single time.
|
||||
|
||||
**Usage during serialization**
|
||||
Each Layer/Model object should have a corresponding instance of
|
||||
SerializedAttributes. Create a new instance by calling
|
||||
`SerializedAttributes.new(obj)`. Objects and functions may be saved using
|
||||
`.set_and_validate_checkpointable_objects`/`.set_and_and_validate_functions`.
|
||||
The properties `.checkpointable_objects` and `.functions` returns the cached
|
||||
values.
|
||||
|
||||
**Adding/changing attributes to save to SavedModel**
|
||||
1. Change the call to `SerializedAttributes.with_attributes` in the correct
|
||||
class:
|
||||
- CommonEndpoints: Base attributes to be added during serialization. If
|
||||
these attributes are present in a Trackable object, it can be
|
||||
deserialized to a Keras Model.
|
||||
- LayerAttributes: Attributes to serialize for Layer objects.
|
||||
- ModelAttributes: Attributes to serialize for Model objects.
|
||||
2. Update class docstring
|
||||
3. Update arguments to any calls to `set_and_validate_*`. For example, if
|
||||
`call_raw_tensors` is added to the ModelAttributes function list, then
|
||||
a `call_raw_tensors` function should be passed to
|
||||
`set_and_validate_functions`.
|
||||
|
||||
**Common endpoints vs other attributes**
|
||||
Only common endpoints are attached directly to the root object. Keras-specific
|
||||
attributes are saved to a separate trackable object with the name "keras_api".
|
||||
The number of objects attached to the root is limited because any naming
|
||||
conflicts will cause user code to break.
|
||||
|
||||
Another reason is that this will only affect users who call
|
||||
`tf.saved_model.load` instead of `tf.keras.models.load_model`. These are
|
||||
advanced users who are likely to have defined their own tf.functions and
|
||||
trackable objects. The added Keras-specific attributes are kept out of the way
|
||||
in the "keras_api" namespace.
|
||||
|
||||
Properties defined in this class may be used to filter out keras-specific
|
||||
attributes:
|
||||
- `functions_to_serialize`: Returns dict of functions to attach to the root
|
||||
object.
|
||||
- `checkpointable_objects_to_serialize`: Returns dict of objects to attach to
|
||||
the root object (including separate trackable object containing
|
||||
keras-specific attributes)
|
||||
|
||||
All changes to the serialized attributes must be backwards-compatible, so
|
||||
attributes should not be removed or modified without sufficient justification.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def with_attributes(
|
||||
name, checkpointable_objects=None, functions=None, copy_from=None):
|
||||
"""Creates a subclass with all attributes as specified in the arguments.
|
||||
|
||||
Args:
|
||||
name: Name of subclass
|
||||
checkpointable_objects: List of checkpointable objects to be serialized
|
||||
in the SavedModel.
|
||||
functions: List of functions to be serialized in the SavedModel.
|
||||
copy_from: List of other SerializedAttributes subclasses. The returend
|
||||
class will copy checkpoint objects/functions from each subclass.
|
||||
|
||||
Returns:
|
||||
Child class with attributes as defined in the `checkpointable_objects`
|
||||
and `functions` lists.
|
||||
"""
|
||||
checkpointable_objects = checkpointable_objects or []
|
||||
functions = functions or []
|
||||
|
||||
if copy_from is not None:
|
||||
for cls in copy_from:
|
||||
checkpointable_objects.extend(cls.all_checkpointable_objects)
|
||||
functions.extend(cls.all_functions)
|
||||
|
||||
classdict = {
|
||||
'all_checkpointable_objects': set(checkpointable_objects),
|
||||
'all_functions': set(functions)}
|
||||
return type(name, (SerializedAttributes,), classdict)
|
||||
|
||||
@staticmethod
|
||||
def new(obj):
|
||||
if isinstance(obj, training_lib.Model):
|
||||
return ModelAttributes()
|
||||
elif isinstance(obj, base_layer.Layer):
|
||||
return LayerAttributes()
|
||||
else:
|
||||
raise TypeError('Internal error during serialization: Expected Keras '
|
||||
'Layer object, got {} of type {}'.format(obj, type(obj)))
|
||||
|
||||
def __init__(self):
|
||||
self._object_dict = {}
|
||||
self._function_dict = {}
|
||||
self._keras_trackable = AutoTrackable()
|
||||
|
||||
@property
|
||||
def functions(self):
|
||||
"""Returns dictionary of all functions."""
|
||||
return {key: value for key, value in self._function_dict.items()
|
||||
if value is not None}
|
||||
|
||||
@property
|
||||
def checkpointable_objects(self):
|
||||
"""Returns dictionary of all checkpointable objects."""
|
||||
return {key: value for key, value in self._object_dict.items()
|
||||
if value is not None}
|
||||
|
||||
@property
|
||||
def functions_to_serialize(self):
|
||||
"""Returns functions to attach to the root object during serialization."""
|
||||
return {key: value for key, value in self.functions.items()
|
||||
if key in CommonEndpoints.all_functions}
|
||||
|
||||
@property
|
||||
def objects_to_serialize(self):
|
||||
"""Returns objects to attach to the root object during serialization."""
|
||||
objects = {key: value for key, value in self.checkpointable_objects.items()
|
||||
if key in CommonEndpoints.all_checkpointable_objects}
|
||||
objects[_KERAS_ATTR] = self._keras_trackable
|
||||
return objects
|
||||
|
||||
def set_and_validate_functions(self, function_dict):
|
||||
"""Saves function dictionary, and validates dictionary values."""
|
||||
for key in self.all_functions:
|
||||
if key in function_dict:
|
||||
if (function_dict[key] is not None and # Not all functions are required
|
||||
not isinstance(function_dict[key],
|
||||
(defun.Function, def_function.Function))):
|
||||
raise ValueError(
|
||||
'Function dictionary contained a non-function object: {} (for key'
|
||||
' {})'.format(function_dict[key], key))
|
||||
self._function_dict[key] = function_dict[key]
|
||||
setattr(self._keras_trackable, key, function_dict[key])
|
||||
else:
|
||||
raise ValueError('Function {} missing from serialized function dict.'
|
||||
.format(key))
|
||||
return self.functions
|
||||
|
||||
def set_and_validate_objects(self, object_dict):
|
||||
"""Saves objects to a dictionary, and validates the values."""
|
||||
for key in self.all_checkpointable_objects:
|
||||
if key in object_dict:
|
||||
if not isinstance(object_dict[key], trackable.Trackable):
|
||||
raise ValueError(
|
||||
'Object dictionary contained a non-trackable object: {} (for key'
|
||||
' {})'.format(object_dict[key], key))
|
||||
self._object_dict[key] = object_dict[key]
|
||||
setattr(self._keras_trackable, key, object_dict[key])
|
||||
else:
|
||||
raise ValueError('Object {} missing from serialized object dict.')
|
||||
return self.checkpointable_objects
|
||||
|
||||
|
||||
class CommonEndpoints(SerializedAttributes.with_attributes(
|
||||
'CommonEndpoints',
|
||||
checkpointable_objects=['variables', 'trainable_variables',
|
||||
'regularization_losses'],
|
||||
functions=['__call__', 'call_and_return_all_conditional_losses',
|
||||
'_default_save_signature'])):
|
||||
"""Common endpoints shared by all models loadable by Keras.
|
||||
|
||||
List of all attributes:
|
||||
variables: List of all variables in the model and its sublayers.
|
||||
trainable_variables: List of all trainable variables in the model and its
|
||||
sublayers.
|
||||
regulariation_losses: List of all unconditional losses (losses not dependent
|
||||
on the inputs) in the model and its sublayers.
|
||||
__call__: Function that takes inputs and returns the outputs of the model
|
||||
call function.
|
||||
call_and_return_all_conditional_losses: Function that returns a tuple of
|
||||
(call function outputs, list of all losses that depend on the inputs).
|
||||
_default_save_signature: Traced model call function. This is only included
|
||||
if the top level exported object is a Keras model.
|
||||
"""
|
||||
|
||||
|
||||
class LayerAttributes(SerializedAttributes.with_attributes(
|
||||
'LayerAttributes',
|
||||
checkpointable_objects=['non_trainable_variables', 'layers', 'metrics'],
|
||||
functions=['call_and_return_conditional_losses', 'activity_regularizer_fn'],
|
||||
copy_from=[CommonEndpoints]
|
||||
)):
|
||||
"""Layer checkpointable objects + functions that are saved to the SavedModel.
|
||||
|
||||
List of all attributes:
|
||||
All attributes from CommonEndpoints
|
||||
non_trainable_variables: List of non-trainable variables in the layer and
|
||||
its sublayers.
|
||||
layers: List of all sublayers.
|
||||
metrics: List of all metrics in the layer and its sublayers.
|
||||
call_and_return_conditional_losses: Function that takes inputs and returns a
|
||||
tuple of (outputs of the call function, list of input-dependent losses).
|
||||
The list of losses excludes the activity regularizer function, which is
|
||||
separate to allow the deserialized Layer object to define a different
|
||||
activity regularizer.
|
||||
activity_regularizer_fn: Callable that returns the activity regularizer loss
|
||||
"""
|
||||
|
||||
|
||||
class ModelAttributes(SerializedAttributes.with_attributes(
|
||||
'ModelAttributes',
|
||||
copy_from=[LayerAttributes])):
|
||||
"""Model checkpointable objects + functions that are saved to the SavedModel.
|
||||
|
||||
List of all attributes:
|
||||
All attributes from LayerAttributes (including CommonEndpoints)
|
||||
"""
|
||||
# TODO(kathywu): Add attributes `compile_losses` and `compile_metrics`, which
|
||||
# list all losses and metrics defined by `model.compile`.
|
||||
|
||||
|
||||
def serialize_all_attributes(layer, serialization_cache):
|
||||
"""Serialize all attributes in the layer."""
|
||||
save_model_default_signature = False
|
||||
if _KERAS_CACHE_KEY not in serialization_cache:
|
||||
keras_cache = serialization_cache[_KERAS_CACHE_KEY] = {}
|
||||
if isinstance(layer, training_lib.Model):
|
||||
# Only trace default signature if the root object is a Model. Since the
|
||||
# keras cache key is only created in this method, we know that the object
|
||||
# is root if the key does not yet exist in the cache.
|
||||
save_model_default_signature = True
|
||||
else:
|
||||
keras_cache = serialization_cache[_KERAS_CACHE_KEY]
|
||||
|
||||
if layer in keras_cache:
|
||||
return keras_cache[layer]
|
||||
serialized_attr = keras_cache[layer] = SerializedAttributes.new(layer)
|
||||
|
||||
if _should_skip_serialization(layer):
|
||||
return serialized_attr
|
||||
|
||||
object_dict = _wrap_layer_objects(layer, serialization_cache)
|
||||
try:
|
||||
function_dict = _wrap_layer_functions(layer, serialization_cache,
|
||||
save_model_default_signature)
|
||||
except (ValueError, TypeError) as e:
|
||||
logging.warning('Skipping full serialization of object {}, because an '
|
||||
'error occurred while tracing layer functions. Error '
|
||||
'message: {}'.format(layer, e.message))
|
||||
else:
|
||||
# Add checkpointable objects and functions to the SerializedAttribute object
|
||||
# only if all functions are successfully traced.
|
||||
# The `set_and_validate_*` function ensures that all required attributes are
|
||||
# exported with the correct type.
|
||||
serialized_attr.set_and_validate_objects(object_dict)
|
||||
serialized_attr.set_and_validate_functions(function_dict)
|
||||
return serialized_attr
|
||||
|
||||
|
||||
def _should_skip_serialization(layer):
|
||||
"""Skip serializing extra objects and functions if layer inputs aren't set."""
|
||||
if isinstance(layer, training_lib.Model):
|
||||
try:
|
||||
# pylint:disable=pointless-statement
|
||||
layer.inputs
|
||||
layer.input_names
|
||||
# pylint:enable=pointless-statement
|
||||
except AttributeError:
|
||||
# If the model does not have inputs set, because it was not called or its
|
||||
# input shapes were not recorded, we won't have a signature so can't trace
|
||||
# a function. But the user may still save an object with this Model
|
||||
# attached; we won't fail the whole tf.saved_model.save.
|
||||
logging.warning('Skipping full serialization of Keras model {}, because '
|
||||
'its inputs are not defined.'.format(layer))
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
if not layer.input_spec:
|
||||
logging.warning('Skipping full serialization of Keras layer {}, because '
|
||||
'it does not have an input spec defined.'.format(layer))
|
||||
return True
|
||||
if not layer.built:
|
||||
logging.warning('Skipping full serialization of Keras layer {}, because '
|
||||
'it is not built.'.format(layer))
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _wrap_layer_objects(layer, serialization_cache):
|
||||
"""Returns extra trackable objects to attach to the serialized layer.
|
||||
|
||||
Args:
|
||||
layer: Keras Layer object.
|
||||
serialization_cache: Dictionary shared between all objects during
|
||||
serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing all checkpointable objects from a
|
||||
SerializedAttributes object. See LayerAttributes and ModelAttributes for
|
||||
entire list of objects
|
||||
"""
|
||||
# Wrap all regularization losses as tf.functions.
|
||||
# First, generate list of all regularization losses in this layer and
|
||||
# sublayers.
|
||||
regularization_losses = layer._callable_losses[:] # pylint: disable=protected-access
|
||||
for child_layer in (
|
||||
trackable_layer_utils.filter_empty_layer_containers(layer._layers)): # pylint: disable=protected-access
|
||||
regularization_losses.extend(child_layer._callable_losses) # pylint: disable=protected-access
|
||||
# Next, wrap all loss functions as tf.functions. Use the serialization cache
|
||||
# to store already-wrapped functions.
|
||||
keras_loss_cache = serialization_cache.setdefault('keras_losses', {})
|
||||
wrapped_loss_functions = []
|
||||
for loss_fn in regularization_losses:
|
||||
if loss_fn in keras_loss_cache:
|
||||
wrapped_loss_functions.append(keras_loss_cache[loss_fn])
|
||||
else:
|
||||
wrapped_loss = _wrap_unconditional_loss(loss_fn, len(keras_loss_cache))
|
||||
keras_loss_cache[wrapped_loss] = wrapped_loss
|
||||
wrapped_loss_functions.append(wrapped_loss)
|
||||
return dict(
|
||||
variables=data_structures.ListWrapper(layer.variables),
|
||||
trainable_variables=data_structures.ListWrapper(
|
||||
layer.trainable_variables),
|
||||
non_trainable_variables=data_structures.ListWrapper(
|
||||
layer.non_trainable_variables),
|
||||
layers=data_structures.ListWrapper(
|
||||
trackable_layer_utils.filter_empty_layer_containers(
|
||||
layer._layers)), # pylint: disable=protected-access
|
||||
metrics=data_structures.ListWrapper(layer.metrics),
|
||||
regularization_losses=data_structures.ListWrapper(
|
||||
wrapped_loss_functions))
|
||||
|
||||
|
||||
def _wrap_layer_functions(layer, serialization_cache,
|
||||
save_model_default_signature=False):
|
||||
"""Returns dict of wrapped layer call function and losses in tf.functions.
|
||||
|
||||
Args:
|
||||
layer: Keras Layer object.
|
||||
serialization_cache: Dictionary shared between all objects during
|
||||
serialization.
|
||||
save_model_default_signature: Whether to save traced model call function.
|
||||
|
||||
Returns:
|
||||
A dictionary containing all keras tf.functions to serialize. See
|
||||
LayerAttributes and ModelAttributes for the list of all attributes.
|
||||
"""
|
||||
# Reset the losses of the layer and its children. The call function in each
|
||||
# child layer is replaced with tf.functions.
|
||||
original_attrs = _replace_child_layer_functions(layer, serialization_cache)
|
||||
original_layer_losses = layer._losses[:] # pylint: disable=protected-access
|
||||
with trackable.no_automatic_dependency_tracking_scope(layer):
|
||||
layer._losses = [] # pylint: disable=protected-access
|
||||
# Note that eager losses do not need to be saved since these functions
|
||||
# create symbolic losses.
|
||||
|
||||
# Wrap all the layer call and activity regularizer functions.
|
||||
call_fn_with_losses = _wrap_call_and_conditional_losses(layer)
|
||||
fns = {'call_and_return_conditional_losses': call_fn_with_losses,
|
||||
'__call__': _extract_outputs_from_fn(layer, call_fn_with_losses)}
|
||||
|
||||
if save_model_default_signature:
|
||||
fns['_default_save_signature'] = saving_utils.trace_model_call(layer)
|
||||
else:
|
||||
fns['_default_save_signature'] = None
|
||||
|
||||
if layer.activity_regularizer is not None:
|
||||
fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
|
||||
fns['call_and_return_all_conditional_losses'] = (
|
||||
_append_activity_regularizer_loss(
|
||||
layer, call_fn_with_losses, fns['activity_regularizer_fn']))
|
||||
else:
|
||||
fns['activity_regularizer_fn'] = None
|
||||
fns['call_and_return_all_conditional_losses'] = call_fn_with_losses
|
||||
|
||||
# Manually trigger traces before restoring the overwritten functions. The
|
||||
# functions are traced within the layer call context to ensure that layer
|
||||
# functions (e.g. add_loss) behave as though running in graph mode.
|
||||
with base_layer_utils.call_context().enter(layer, None, build_graph=True):
|
||||
for fn in fns.values():
|
||||
if fn is not None and fn.input_signature is not None:
|
||||
fn.get_concrete_function()
|
||||
|
||||
# Restore overwritten functions/losses
|
||||
with trackable.no_automatic_dependency_tracking_scope(layer):
|
||||
layer._losses = original_layer_losses # pylint: disable=protected-access
|
||||
_restore_child_layer_functions(original_attrs)
|
||||
|
||||
return fns
|
||||
|
||||
|
||||
def _replace_child_layer_functions(layer, serialization_cache):
|
||||
"""Replaces functions in the children layers with wrapped tf.functions.
|
||||
|
||||
This step allows functions from parent layers to reference the wrapped
|
||||
functions from their children layers instead of retracing the ops.
|
||||
|
||||
This function also resets all losses stored in the layer. These are stored in
|
||||
the returned dictionary. Use `_restore_child_layer_functions` to restore
|
||||
the original attributes.
|
||||
|
||||
Args:
|
||||
layer: Keras Layer object.
|
||||
serialization_cache: Dictionary shared between all objects during
|
||||
serialization.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping layer objects -> original functions and losses:
|
||||
{ Child layer 1: {
|
||||
'losses': Original losses,
|
||||
'call': Original call function
|
||||
'activity_regularizer': Original activity regularizer},
|
||||
Child layer 2: ...
|
||||
}
|
||||
"""
|
||||
original_attrs = {}
|
||||
for child_layer in trackable_layer_utils.filter_empty_layer_containers(
|
||||
layer._layers): # pylint: disable=protected-access
|
||||
# Save symbolic layer losses, which will be restored to maintain the same
|
||||
# state.
|
||||
original_attrs[child_layer] = {'losses': child_layer._losses[:]} # pylint: disable=protected-access
|
||||
if child_layer not in serialization_cache[_KERAS_CACHE_KEY]:
|
||||
layer_fns = (serialize_all_attributes(child_layer, serialization_cache)
|
||||
.functions)
|
||||
else:
|
||||
layer_fns = serialization_cache[_KERAS_CACHE_KEY][child_layer].functions
|
||||
if not layer_fns:
|
||||
# This indicates either:
|
||||
# - circular dependency, which means the current layer's functions
|
||||
# should be wrapped first.
|
||||
# - Child layer's inputs are not defined, so its functions have not been
|
||||
# wrapped. In this case, no replacement is necessary so move on to the
|
||||
# next child.
|
||||
continue
|
||||
|
||||
original_attrs[child_layer]['call'] = child_layer.call
|
||||
original_attrs[child_layer]['activity_regularizer'] = (
|
||||
child_layer.activity_regularizer)
|
||||
with trackable.no_automatic_dependency_tracking_scope(child_layer):
|
||||
child_layer.activity_regularizer = layer_fns.get(
|
||||
'activity_regularizer_fn')
|
||||
child_layer.call = _use_wrapped_call(
|
||||
child_layer, layer_fns['call_and_return_conditional_losses'])
|
||||
child_layer._losses = [] # pylint: disable=protected-access
|
||||
return original_attrs
|
||||
|
||||
|
||||
def _restore_child_layer_functions(original_attrs):
|
||||
"""Restores attributes replaced with `_replace_child_layer_functions`."""
|
||||
for child_layer, attrs in original_attrs.items():
|
||||
with trackable.no_automatic_dependency_tracking_scope(child_layer):
|
||||
child_layer._losses = attrs['losses'] # pylint: disable=protected-access
|
||||
if 'call' in attrs:
|
||||
child_layer.call = attrs['call']
|
||||
child_layer.activity_regularizer = attrs['activity_regularizer']
|
||||
|
||||
|
||||
def _use_wrapped_call(layer, call_fn):
|
||||
"""Creates fn that adds the losses returned by call_fn & returns the outputs.
|
||||
|
||||
Args:
|
||||
layer: A Keras layer object
|
||||
call_fn: tf.function returned by _wrap_call_and_conditional_losses.
|
||||
|
||||
Returns:
|
||||
function that calls call_fn and returns the outputs. Losses returned by
|
||||
call_fn are added to the layer losses.
|
||||
"""
|
||||
def wrapped_call(inputs, *args, **kwargs):
|
||||
"""Returns the outputs from the call_fn, and adds the losses."""
|
||||
if layer._expects_training_arg: # pylint: disable=protected-access
|
||||
training = kwargs.pop('training', None)
|
||||
if training is None:
|
||||
training = K.learning_phase()
|
||||
training = math_ops.cast(training, dtypes.bool)
|
||||
outputs, losses = call_fn(inputs, training=training, *args, **kwargs)
|
||||
else:
|
||||
outputs, losses = call_fn(inputs)
|
||||
layer.add_loss(losses, inputs)
|
||||
return outputs
|
||||
return wrapped_call
|
||||
|
||||
|
||||
def _wrap_call_and_conditional_losses(layer):
|
||||
"""Wraps call function that returns a tuple of (outputs, losses).
|
||||
|
||||
The losses returned are conditional on the inputs passed to the call function.
|
||||
Unconditional losses (e.g. weight regularizeration) are wrapped separately.
|
||||
|
||||
Args:
|
||||
layer: a Keras layer object
|
||||
|
||||
Returns:
|
||||
call function that returns outputs and conditional losses -- excludes
|
||||
activity regularizer
|
||||
"""
|
||||
if isinstance(layer, RevivedLayer):
|
||||
return layer.call_and_return_conditional_losses
|
||||
|
||||
if (isinstance(layer.call, def_function.Function) and
|
||||
layer.call.input_signature is not None):
|
||||
input_signature = layer.call.input_signature
|
||||
else:
|
||||
if (isinstance(layer, training_lib.Model) and
|
||||
saving_utils.model_input_signature(layer) is not None):
|
||||
input_signature = saving_utils.model_input_signature(layer)
|
||||
else:
|
||||
input_signature = [nest.map_structure(
|
||||
lambda x: input_spec.to_tensor_spec(x, layer.dtype),
|
||||
layer.input_spec)]
|
||||
# If input spec is too general, then don't define an input signature.
|
||||
for spec in nest.flatten(input_signature):
|
||||
if spec.shape == tensor_shape.TensorShape(None):
|
||||
input_signature = None
|
||||
break
|
||||
|
||||
if input_signature is not None and layer._expects_training_arg: # pylint: disable=protected-access
|
||||
input_signature.append(
|
||||
tensor_spec.TensorSpec(shape=[], dtype=dtypes.bool))
|
||||
|
||||
# Create function that generates both outputs and losses
|
||||
layer_call = layer.call
|
||||
if layer._expects_training_arg: # pylint: disable=protected-access
|
||||
def call_and_return_conditional_losses(inputs, training):
|
||||
_set_symbolic_learning_phase(training)
|
||||
return layer_call(inputs, training=training), layer.get_losses_for(inputs)
|
||||
else:
|
||||
def call_and_return_conditional_losses(inputs):
|
||||
K.set_learning_phase(0)
|
||||
return layer_call(inputs), layer.get_losses_for(inputs)
|
||||
return def_function.Function(
|
||||
call_and_return_conditional_losses,
|
||||
'{}_layer_call_and_return_conditional_losses'.format(layer.name),
|
||||
input_signature=input_signature,
|
||||
# TODO(kathywu): Investigate autograph error.
|
||||
autograph=False)
|
||||
|
||||
|
||||
def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
|
||||
"""Returns a function that returns only call function outputs."""
|
||||
if isinstance(layer, RevivedLayer):
|
||||
return layer._original_call # pylint: disable=protected-access
|
||||
if layer._expects_training_arg: # pylint: disable=protected-access
|
||||
def call(inputs, training):
|
||||
return call_and_return_conditional_losses(inputs, training)[0]
|
||||
else:
|
||||
def call(inputs):
|
||||
return call_and_return_conditional_losses(inputs)[0]
|
||||
return def_function.Function(
|
||||
call, '{}_layer_call_fn'.format(layer.name),
|
||||
input_signature=call_and_return_conditional_losses.input_signature,
|
||||
# TODO(kathywu): Investigate autograph error.
|
||||
autograph=False)
|
||||
|
||||
|
||||
def _set_symbolic_learning_phase(value):
|
||||
"""Set learning phase to a tensor value (for internal use only).
|
||||
|
||||
This is used when wrapping call functions as tf.functions that have training
|
||||
as a tensor input. Thus, when `learning_phase()` is called, the training
|
||||
tensor is returned. This function is called when saving a model to SavedModel.
|
||||
|
||||
Args:
|
||||
value: A Tensor object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input value is not a graph tensor
|
||||
"""
|
||||
graph = K.get_graph()
|
||||
if not isinstance(value, ops.Tensor):
|
||||
raise ValueError('Symbolic learning phase must be a graph tensor.')
|
||||
K._GRAPH_LEARNING_PHASES[graph] = value # pylint: disable=protected-access
|
||||
|
||||
|
||||
def _append_activity_regularizer_loss(
|
||||
layer, call_fn_with_losses, activity_regularizer_fn):
|
||||
"""Appends activity regularizer loss to losses returned by the wrapped fn."""
|
||||
def fn(*args):
|
||||
outputs, losses = call_fn_with_losses(*args)
|
||||
losses.append(activity_regularizer_fn(outputs))
|
||||
return outputs, losses
|
||||
return def_function.Function(
|
||||
fn,
|
||||
'{}_layer_call_and_return_all_conditional_losses'.format(layer.name),
|
||||
input_signature=call_fn_with_losses.input_signature,
|
||||
# TODO(kathywu): Investigate autograph error.
|
||||
autograph=False)
|
||||
|
||||
|
||||
def _wrap_unconditional_loss(loss_fn, index):
|
||||
"""Wraps callable/unconditonal loss, returning a serializable function."""
|
||||
# Extract original loss function from partial function
|
||||
fn = loss_fn.args[0] if isinstance(loss_fn, functools.partial) else loss_fn
|
||||
if isinstance(fn, def_function.Function):
|
||||
return fn
|
||||
else:
|
||||
return def_function.Function(
|
||||
fn, 'loss_fn_{}'.format(index), input_signature=[])
|
||||
|
||||
|
||||
def _wrap_activity_regularizer(layer):
|
||||
"""Wraps the activity regularizer."""
|
||||
if isinstance(layer.activity_regularizer, def_function.Function):
|
||||
return layer.activity_regularizer
|
||||
return def_function.Function(
|
||||
layer.activity_regularizer,
|
||||
'{}_activity_regularizer'.format(layer.name),
|
||||
input_signature=[tensor_spec.TensorSpec(None, layer.dtype or K.floatx())])
|
||||
|
||||
|
||||
def load_from_saved_model_v2(path):
|
||||
"""Loads Keras objects from a SavedModel.
|
||||
|
||||
Any Keras layer or model saved to the SavedModel will be loaded back
|
||||
as Keras objects. Other objects are loaded as regular trackable objects (same
|
||||
as `tf.saved_model.load`).
|
||||
|
||||
Currently, Keras saving/loading only retains the Keras object's weights,
|
||||
losses, and call function.
|
||||
|
||||
The loaded model can be re-compiled, but the original optimizer, compiled loss
|
||||
functions, and metrics are not retained. This is temporary, and `model.save`
|
||||
will soon be able to serialize compiled models.
|
||||
|
||||
Args:
|
||||
path: Path to SavedModel.
|
||||
|
||||
Returns:
|
||||
Object loaded from SavedModel.
|
||||
"""
|
||||
# TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
|
||||
# TODO(kathywu): Add code to load from objects that contain all endpoints
|
||||
return load.load_internal(path, loader_cls=KerasObjectLoader)
|
||||
|
||||
|
||||
class KerasObjectLoader(load.Loader):
|
||||
"""Loader that recreates Keras objects."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(KerasObjectLoader, self).__init__(*args, **kwargs)
|
||||
self._finalize()
|
||||
|
||||
def _finalize(self):
|
||||
# pylint: disable=protected-access
|
||||
for node in self._nodes:
|
||||
if isinstance(node, RevivedLayer):
|
||||
losses = node._serialized_attributes.get('regularization_losses', [])
|
||||
for loss in losses:
|
||||
node.add_loss(loss)
|
||||
|
||||
# Use wrapped activity regularizer function if the layer's activity
|
||||
# regularizer wasn't created during initialization.
|
||||
if node.activity_regularizer is None:
|
||||
node.activity_regularizer = getattr(node.keras_api,
|
||||
'activity_regularizer_fn', None)
|
||||
|
||||
if isinstance(node, RevivedModel):
|
||||
# Since this revived object is technically a subclassed model (even if
|
||||
# the original model is functional/sequential), inputs should be set.
|
||||
input_signature = (
|
||||
node.keras_api.call_and_return_conditional_losses.input_signature[0]
|
||||
)
|
||||
node._set_inputs(input_signature)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def _recreate_base_user_object(self, proto):
|
||||
revived_classes = {
|
||||
'_tf_keras_layer': (RevivedLayer, base_layer.Layer),
|
||||
'_tf_keras_network': (RevivedNetwork, network_lib.Network),
|
||||
'_tf_keras_model': (RevivedModel, training_lib.Model)
|
||||
}
|
||||
|
||||
parent_classes = revived_classes.get(proto.identifier, None)
|
||||
|
||||
if parent_classes is not None:
|
||||
parent_classes = revived_classes[proto.identifier]
|
||||
metadata = json.loads(proto.metadata)
|
||||
revived_cls = type(
|
||||
compat.as_str(metadata['class_name']),
|
||||
parent_classes,
|
||||
{'__setattr__': parent_classes[1].__setattr__})
|
||||
obj = revived_cls._init_from_metadata(metadata) # pylint: disable=protected-access
|
||||
return obj, revived_cls._revive_setter # pylint: disable=protected-access
|
||||
|
||||
return super(KerasObjectLoader, self)._recreate_base_user_object(proto)
|
||||
|
||||
|
||||
# TODO(kathywu): Centrally define keys and functions for both serialization and
|
||||
# deserialization.
|
||||
class RevivedLayer(object):
|
||||
"""Keras layer loaded from a SavedModel."""
|
||||
|
||||
@classmethod
|
||||
def _init_from_metadata(cls, metadata):
|
||||
"""Create revived layer from metadata stored in the SavedModel proto."""
|
||||
init_args = dict(
|
||||
name=metadata['name'],
|
||||
trainable=metadata['trainable'])
|
||||
if metadata.get('dtype') is not None:
|
||||
init_args['dtype'] = metadata['dtype']
|
||||
if metadata.get('batch_input_shape') is not None:
|
||||
init_args['batch_input_shape'] = metadata['batch_input_shape']
|
||||
|
||||
revived_obj = cls(**init_args)
|
||||
|
||||
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
|
||||
# pylint:disable=protected-access
|
||||
revived_obj._expects_training_arg = metadata['expects_training_arg']
|
||||
if metadata.get('config') is not None:
|
||||
revived_obj._config = metadata['config']
|
||||
if metadata.get('input_spec') is not None:
|
||||
revived_obj.input_spec = input_spec.InputSpec.from_config(
|
||||
metadata['input_spec'])
|
||||
if metadata.get('activity_regularizer') is not None:
|
||||
revived_obj.activity_regularizer = regularizers.deserialize(
|
||||
metadata['activity_regularizer'])
|
||||
|
||||
# Store attributes revived from SerializedAttributes in a un-tracked
|
||||
# dictionary. The attributes are the ones listed in CommonEndpoints or
|
||||
# "keras_api" for keras-specific attributes.
|
||||
revived_obj._serialized_attributes = {}
|
||||
# pylint:enable=protected-access
|
||||
|
||||
return revived_obj
|
||||
|
||||
def _revive_setter(self, name, value):
|
||||
"""Reattaches attributes from the SavedModel to the newly revived object."""
|
||||
if (name in CommonEndpoints.all_functions or
|
||||
name in CommonEndpoints.all_checkpointable_objects or
|
||||
name == _KERAS_ATTR):
|
||||
self._serialized_attributes[name] = value
|
||||
else:
|
||||
setattr(self, name, value)
|
||||
|
||||
@property
|
||||
def keras_api(self):
|
||||
return self._serialized_attributes[_KERAS_ATTR]
|
||||
|
||||
def get_config(self):
|
||||
if hasattr(self, '_config'):
|
||||
return self._config
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, *args, **kwargs):
|
||||
"""Calls the revived layer and add conditional losses."""
|
||||
call_fn = _use_wrapped_call(
|
||||
self, self.keras_api.call_and_return_conditional_losses)
|
||||
return call_fn(inputs, *args, **kwargs)
|
||||
|
||||
|
||||
class RevivedNetwork(RevivedLayer):
|
||||
"""Keras network of layers loaded from a SavedModel."""
|
||||
|
||||
@classmethod
|
||||
def _init_from_metadata(cls, metadata):
|
||||
"""Create revived network from metadata stored in the SavedModel proto."""
|
||||
# TODO(kathywu): Refactor logic here so that RevivedNetwork uses the
|
||||
revived_obj = cls(name=metadata['name'])
|
||||
|
||||
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
|
||||
# pylint:disable=protected-access
|
||||
if metadata.get('dtype') is not None:
|
||||
revived_obj._dtype = metadata['dtype']
|
||||
revived_obj.trainable = metadata['trainable']
|
||||
|
||||
revived_obj._expects_training_arg = metadata['expects_training_arg']
|
||||
if metadata.get('config') is not None:
|
||||
revived_obj._config = metadata['config']
|
||||
|
||||
if metadata.get('activity_regularizer') is not None:
|
||||
revived_obj.activity_regularizer = regularizers.deserialize(
|
||||
metadata['activity_regularizer'])
|
||||
|
||||
# Store attributes revived from SerializedAttributes in a un-tracked
|
||||
# dictionary. The attributes are the ones listed in CommonEndpoints or
|
||||
# "keras_api" for keras-specific attributes.
|
||||
revived_obj._serialized_attributes = {}
|
||||
# pylint:enable=protected-access
|
||||
|
||||
return revived_obj
|
||||
|
||||
|
||||
class RevivedModel(RevivedNetwork):
|
||||
"""Keras model loaded from a SavedModel."""
|
||||
|
||||
@classmethod
|
||||
def _init_from_metadata(cls, metadata):
|
||||
"""Create revived model from metadata stored in the SavedModel proto."""
|
||||
revived_obj = super(RevivedModel, cls)._init_from_metadata(metadata)
|
||||
|
||||
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
|
||||
if 'training_config' in metadata:
|
||||
revived_obj._training_config = metadata['training_config'] # pylint:disable=protected-access
|
||||
|
||||
return revived_obj
|
||||
|
@ -28,19 +28,28 @@ from tensorflow.python import keras
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import regularizers
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import training as model_lib
|
||||
from tensorflow.python.keras.optimizer_v2 import adadelta
|
||||
from tensorflow.python.keras.saving import saved_model as keras_saved_model
|
||||
from tensorflow.python.keras.utils import mode_keys
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import load as tf_load
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.saved_model import model_utils
|
||||
from tensorflow.python.saved_model import save as tf_save
|
||||
from tensorflow.python.training import training as training_module
|
||||
|
||||
|
||||
@ -186,7 +195,7 @@ class TestModelSavingandLoading(test.TestCase):
|
||||
# For now, saving subclassed model should raise an error. It should be
|
||||
# avoided later with loading from SavedModel.pb.
|
||||
|
||||
class SubclassedModel(training.Model):
|
||||
class SubclassedModel(model_lib.Model):
|
||||
|
||||
def __init__(self):
|
||||
super(SubclassedModel, self).__init__()
|
||||
@ -205,10 +214,15 @@ class TestModelSavingandLoading(test.TestCase):
|
||||
|
||||
class LayerWithLearningPhase(keras.engine.base_layer.Layer):
|
||||
|
||||
def call(self, x):
|
||||
phase = keras.backend.learning_phase()
|
||||
def build(self, input_shape):
|
||||
self.input_spec = keras.layers.InputSpec(shape=[None] * len(input_shape))
|
||||
self.built = True
|
||||
|
||||
def call(self, x, training=None):
|
||||
if training is None:
|
||||
training = keras.backend.learning_phase()
|
||||
output = tf_utils.smart_cond(
|
||||
phase, lambda: x * 0, lambda: array_ops.identity(x))
|
||||
training, lambda: x * 0, lambda: array_ops.identity(x))
|
||||
if not context.executing_eagerly():
|
||||
output._uses_learning_phase = True # pylint: disable=protected-access
|
||||
return output
|
||||
@ -538,5 +552,171 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllClose(ref_predict, predictions, atol=1e-05)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
|
||||
def _save_model_dir(self, dirname='saved_model'):
|
||||
temp_dir = self.get_temp_dir()
|
||||
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
|
||||
return os.path.join(temp_dir, dirname)
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_model_save_and_load(self):
|
||||
input_arr = np.random.random((1, 3)).astype(np.float32)
|
||||
target_arr = np.random.random((1, 4)).astype(np.float32)
|
||||
|
||||
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
|
||||
model.layers[-1].activity_regularizer = regularizers.get('l2')
|
||||
model.activity_regularizer = regularizers.get('l2')
|
||||
model.compile(
|
||||
loss='mse',
|
||||
optimizer='rmsprop')
|
||||
model.train_on_batch(input_arr, target_arr)
|
||||
|
||||
def callable_loss():
|
||||
return math_ops.reduce_sum(model.weights[0])
|
||||
model.add_loss(callable_loss)
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
|
||||
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
self.assertAllClose(self.evaluate(model.weights),
|
||||
self.evaluate(loaded.weights))
|
||||
|
||||
input_arr = constant_op.constant(
|
||||
np.random.random((1, 3)).astype(np.float32))
|
||||
self.assertAllClose(self.evaluate(model(input_arr)),
|
||||
self.evaluate(loaded(input_arr)))
|
||||
# Validate losses. The order of conditional losses may change between the
|
||||
# model and loaded model, so sort the losses first.
|
||||
if context.executing_eagerly():
|
||||
self.assertAllClose(sorted(self.evaluate(model.losses)),
|
||||
sorted(self.evaluate(loaded.losses)))
|
||||
else:
|
||||
self.assertAllClose(self.evaluate(model.get_losses_for(None)),
|
||||
self.evaluate(loaded.get_losses_for(None)))
|
||||
self.assertAllClose(
|
||||
sorted(self.evaluate(model.get_losses_for(input_arr))),
|
||||
sorted(self.evaluate(loaded.get_losses_for(input_arr))))
|
||||
|
||||
def test_trainable_weights(self):
|
||||
layer = keras.layers.Dense(4, name='custom_layer')
|
||||
layer.build([3,])
|
||||
layer.add_weight(
|
||||
'extra_weight', shape=[],
|
||||
initializer=init_ops.constant_initializer(11),
|
||||
trainable=True)
|
||||
layer.add_weight(
|
||||
'extra_weight_2', shape=[],
|
||||
initializer=init_ops.constant_initializer(12),
|
||||
trainable=False)
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
self.evaluate(variables.variables_initializer(layer.variables))
|
||||
tf_save.save(layer, saved_model_dir)
|
||||
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
|
||||
equal_attrs = ['name', '_expects_training_arg', 'trainable']
|
||||
for attr in equal_attrs:
|
||||
self.assertEqual(getattr(layer, attr), getattr(loaded, attr))
|
||||
|
||||
all_close = ['weights', 'trainable_weights', 'non_trainable_weights']
|
||||
for attr in all_close:
|
||||
self.assertAllClose(self.evaluate(getattr(layer, attr)),
|
||||
self.evaluate(getattr(loaded, attr)))
|
||||
|
||||
def test_maintains_losses(self):
|
||||
"""Tests that the layer losses do not change before and after export."""
|
||||
|
||||
class LayerWithLoss(keras.layers.Layer):
|
||||
|
||||
def call(self, inputs):
|
||||
self.add_loss(math_ops.reduce_sum(inputs), inputs)
|
||||
return inputs
|
||||
|
||||
model = keras.models.Sequential([LayerWithLoss()])
|
||||
model.compile(
|
||||
loss='mse',
|
||||
optimizer='rmsprop')
|
||||
input_arr = np.random.random((1, 3)).astype(np.float32)
|
||||
target_arr = np.random.random((1, 3)).astype(np.float32)
|
||||
|
||||
# Test that symbolic losses are maintained (train_on_batch saves symbolic
|
||||
# losses.)
|
||||
model.train_on_batch(input_arr, target_arr)
|
||||
previous_losses = model.losses[:]
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
self.assertAllEqual(previous_losses, model.losses)
|
||||
|
||||
if context.executing_eagerly():
|
||||
# Test that eager losses are maintained.
|
||||
model(input_arr) # Calls model eagerly, creating eager losses.
|
||||
previous_losses = model.losses[:]
|
||||
tf_save.save(model, saved_model_dir)
|
||||
self.assertAllEqual(previous_losses, model.losses)
|
||||
|
||||
def test_layer_with_learning_phase(self):
|
||||
layer = LayerWithLearningPhase()
|
||||
layer.build([None, None])
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(layer, saved_model_dir)
|
||||
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
|
||||
input_arr = array_ops.ones((4, 3))
|
||||
|
||||
# Run the layer, and use the keras backend learing phase
|
||||
keras.backend.set_learning_phase(0)
|
||||
self.assertAllEqual(input_arr, loaded(input_arr))
|
||||
keras.backend.set_learning_phase(1)
|
||||
self.assertAllEqual(array_ops.zeros((4, 3)), loaded(input_arr))
|
||||
|
||||
# Run the layer while explicitly setting the training argument
|
||||
self.assertAllEqual(
|
||||
input_arr, loaded(input_arr, training=constant_op.constant(False)))
|
||||
self.assertAllEqual(
|
||||
array_ops.zeros((4, 3)),
|
||||
loaded(input_arr, training=constant_op.constant(True)))
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_standard_loader(self):
|
||||
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
|
||||
model.activity_regularizer = regularizers.get('l2')
|
||||
def eager_loss():
|
||||
return math_ops.reduce_sum(model.weights[0])
|
||||
model.add_loss(eager_loss)
|
||||
|
||||
# Call predict to ensure that all layers are built and inputs are set.
|
||||
model.predict(np.random.random((1, 3)))
|
||||
saved_model_dir = self._save_model_dir()
|
||||
|
||||
tf_save.save(model, saved_model_dir)
|
||||
|
||||
loaded = tf_load.load(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
all_close = ['variables', 'trainable_variables',
|
||||
'non_trainable_variables']
|
||||
for attr in all_close:
|
||||
self.assertAllClose(self.evaluate(getattr(model, attr)),
|
||||
self.evaluate(getattr(loaded.keras_api, attr)))
|
||||
self.assertLen(loaded.regularization_losses, 1)
|
||||
expected_layers = len(model.layers)
|
||||
if testing_utils.get_model_type() == 'sequential':
|
||||
# The autogenerated Input layer is hidden in the model.layers list,
|
||||
# but included in the loaded sub-layers.
|
||||
expected_layers += 1
|
||||
self.assertEqual(expected_layers, len(loaded.keras_api.layers))
|
||||
input_arr = array_ops.ones((4, 3))
|
||||
training_bool = constant_op.constant(False)
|
||||
|
||||
if model._expects_training_arg:
|
||||
call_args = [input_arr, training_bool]
|
||||
else:
|
||||
call_args = [input_arr]
|
||||
self.assertAllClose(self.evaluate(model(input_arr)),
|
||||
self.evaluate(loaded(*call_args)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# pylint: disable=protected-access
|
||||
"""Utils related to keras model saving."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -22,6 +21,9 @@ import collections
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
@ -43,7 +45,53 @@ def extract_model_metrics(model):
|
||||
# TODO(psv/kathywu): use this implementation in model to estimator flow.
|
||||
# We are not using model.metrics here because we want to exclude the metrics
|
||||
# added using `add_metric` API.
|
||||
return {m.name: m for m in model._compile_metric_functions}
|
||||
return {m.name: m for m in model._compile_metric_functions} # pylint: disable=protected-access
|
||||
|
||||
|
||||
def model_input_signature(model):
|
||||
"""Inspect model to get its input signature.
|
||||
|
||||
The model's input signature is a list with a single (possibly-nested) object.
|
||||
This is due to the Keras-enforced restriction that tensor inputs must be
|
||||
passed in as the first argument.
|
||||
|
||||
For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>}
|
||||
will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}]
|
||||
|
||||
Args:
|
||||
model: Keras Model object
|
||||
|
||||
Returns:
|
||||
A list containing either a single TensorSpec or an object with nested
|
||||
TensorSpecs.
|
||||
"""
|
||||
try:
|
||||
inputs = model.inputs
|
||||
input_names = model.input_names
|
||||
except AttributeError:
|
||||
return None
|
||||
flat_inputs = nest.flatten(inputs)
|
||||
flat_input_names = nest.flatten(input_names)
|
||||
flat_input_specs = []
|
||||
for input_tensor, input_name in zip(flat_inputs, flat_input_names):
|
||||
# If the user has not explicitly provided the input_signature, we
|
||||
# create it from the inputs. We make sure to set the first dimension
|
||||
# (batch) to None here, as in serving or retraining, batch should not
|
||||
# be fixed. See b/132783590 for context.
|
||||
input_shape = [None] + input_tensor.shape[1:].as_list()
|
||||
flat_input_specs.append(tensor_spec.TensorSpec(
|
||||
shape=input_shape, dtype=input_tensor.dtype,
|
||||
name=input_name))
|
||||
input_specs = nest.pack_sequence_as(structure=inputs,
|
||||
flat_sequence=flat_input_specs)
|
||||
|
||||
# Return a list with a single element as the model's input signature.
|
||||
if isinstance(input_specs, collections.Sequence) and len(input_specs) == 1:
|
||||
# Note that the isinstance check filters out single-element dictionaries,
|
||||
# which should also be wrapped as a single-element list.
|
||||
return input_specs
|
||||
else:
|
||||
return [input_specs]
|
||||
|
||||
|
||||
def trace_model_call(model, input_signature=None):
|
||||
@ -65,37 +113,14 @@ def trace_model_call(model, input_signature=None):
|
||||
input_signature = model.call.input_signature
|
||||
|
||||
if input_signature is None:
|
||||
try:
|
||||
inputs = model.inputs
|
||||
input_names = model.input_names
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
'Model {} cannot be saved because the input shapes have not been '
|
||||
'set. Usually, input shapes are automatically determined from calling'
|
||||
' .fit() or .predict(). To manually set the shapes, call '
|
||||
'model._set_inputs(inputs).'.format(model))
|
||||
flat_inputs = nest.flatten(inputs)
|
||||
flat_input_names = nest.flatten(input_names)
|
||||
flat_input_specs = []
|
||||
for input_tensor, input_name in zip(flat_inputs, flat_input_names):
|
||||
# If the user has not explicitly provided the input_signature, we
|
||||
# create it from the inputs. We make sure to set the first dimension
|
||||
# (batch) to None here, as in serving or retraining, batch should not
|
||||
# be fixed. See b/132783590 for context.
|
||||
input_shape = [None] + input_tensor.shape[1:].as_list()
|
||||
flat_input_specs.append(tensor_spec.TensorSpec(
|
||||
shape=input_shape, dtype=input_tensor.dtype,
|
||||
name=input_name))
|
||||
input_specs = nest.pack_sequence_as(structure=inputs,
|
||||
flat_sequence=flat_input_specs)
|
||||
# The input signature of the call function is a list with one element, since
|
||||
# all tensor inputs must be passed in as the first argument. Single-element
|
||||
# dictionaries and other non-sequence types must also be wrapped.
|
||||
if (len(input_specs) > 1
|
||||
or not isinstance(input_specs, collections.Sequence)):
|
||||
input_signature = [input_specs]
|
||||
else:
|
||||
input_signature = input_specs
|
||||
input_signature = model_input_signature(model)
|
||||
|
||||
if input_signature is None:
|
||||
raise ValueError(
|
||||
'Model {} cannot be saved because the input shapes have not been '
|
||||
'set. Usually, input shapes are automatically determined from calling'
|
||||
' .fit() or .predict(). To manually set the shapes, call '
|
||||
'model._set_inputs(inputs).'.format(model))
|
||||
|
||||
# TODO(mdan): Should the model's call be autographed by default?
|
||||
@def_function.function(input_signature=input_signature, autograph=False)
|
||||
@ -113,3 +138,55 @@ def trace_model_call(model, input_signature=None):
|
||||
return {name: output for name, output in zip(output_names, outputs_list)}
|
||||
|
||||
return _wrapped_model
|
||||
|
||||
|
||||
def model_metadata(model, include_optimizer=True, require_config=True):
|
||||
"""Returns a dictionary containing the model metadata."""
|
||||
from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2 # pylint: disable=g-import-not-at-top
|
||||
|
||||
model_config = {'class_name': model.__class__.__name__}
|
||||
try:
|
||||
model_config['config'] = model.get_config()
|
||||
except NotImplementedError as e:
|
||||
if require_config:
|
||||
raise e
|
||||
|
||||
metadata = dict(
|
||||
keras_version=str(keras_version),
|
||||
backend=K.backend(),
|
||||
model_config=model_config)
|
||||
if model.optimizer and include_optimizer:
|
||||
if isinstance(model.optimizer, optimizers.TFOptimizer):
|
||||
logging.warning(
|
||||
'TensorFlow optimizers do not '
|
||||
'make it possible to access '
|
||||
'optimizer attributes or optimizer state '
|
||||
'after instantiation. '
|
||||
'As a result, we cannot save the optimizer '
|
||||
'as part of the model save file. '
|
||||
'You will have to compile your model again after loading it. '
|
||||
'Prefer using a Keras optimizer instead '
|
||||
'(see keras.io/optimizers).')
|
||||
else:
|
||||
metadata['training_config'] = {
|
||||
'loss': model.loss,
|
||||
# pylint: disable=protected-access
|
||||
'metrics': model._compile_metrics,
|
||||
'weighted_metrics': model._compile_weighted_metrics,
|
||||
# pylint: enable=protected-access
|
||||
'sample_weight_mode': model.sample_weight_mode,
|
||||
'loss_weights': model.loss_weights,
|
||||
}
|
||||
if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
|
||||
raise NotImplementedError(
|
||||
'As of now, Optimizers loaded from SavedModel cannot be saved. '
|
||||
'If you\'re calling `model.save` or `tf.keras.models.save_model`, '
|
||||
'please set the `include_optimizer` option to `False`. For '
|
||||
'`tf.saved_model.save`, delete the optimizer from the model.')
|
||||
else:
|
||||
optimizer_config = {
|
||||
'class_name': model.optimizer.__class__.__name__,
|
||||
'config': model.optimizer.get_config()}
|
||||
metadata['training_config']['optimizer_config'] = optimizer_config
|
||||
return metadata
|
||||
|
@ -99,7 +99,7 @@ class _WrapperFunction(function.ConcreteFunction):
|
||||
return super(_WrapperFunction, self)._call_flat(args, captured_inputs)
|
||||
|
||||
|
||||
class _Loader(object):
|
||||
class Loader(object):
|
||||
"""Helper class to load an object-based SavedModel."""
|
||||
|
||||
def __init__(self, object_graph_proto, saved_model_proto, export_dir):
|
||||
@ -249,7 +249,7 @@ class _Loader(object):
|
||||
# Note: if an object has an attribute `__call__` add a class method
|
||||
# that allows `obj()` syntax to work. This is done per-instance to
|
||||
# allow `callable` to be used to find out if an object is callable.
|
||||
if reference.local_name == "__call__":
|
||||
if reference.local_name == "__call__" and not callable(obj):
|
||||
setattr(type(obj), "__call__", _call_attribute)
|
||||
|
||||
def _restore_checkpoint(self):
|
||||
@ -308,16 +308,20 @@ class _Loader(object):
|
||||
"""Instantiates a SavedUserObject."""
|
||||
looked_up = revived_types.deserialize(proto)
|
||||
if looked_up is None:
|
||||
# Note: each user object has its own class. This allows to make each one
|
||||
# individually callable by adding a `__call__` method to the classes of
|
||||
# the objects instances that have a `__call__` property.
|
||||
|
||||
class _UserObject(tracking.AutoTrackable):
|
||||
pass
|
||||
|
||||
return _UserObject(), setattr
|
||||
return self._recreate_base_user_object(proto)
|
||||
return looked_up
|
||||
|
||||
def _recreate_base_user_object(self, proto):
|
||||
del proto
|
||||
# Note: each user object has its own class. This allows to make each one
|
||||
# individually callable by adding a `__call__` method to the classes of
|
||||
# the objects instances that have a `__call__` property.
|
||||
|
||||
class _UserObject(tracking.AutoTrackable):
|
||||
pass
|
||||
|
||||
return _UserObject(), setattr
|
||||
|
||||
def _recreate_asset(self, proto):
|
||||
filename = os.path.join(
|
||||
saved_model_utils.get_assets_dir(self._export_dir),
|
||||
@ -386,7 +390,7 @@ class _RestoredResource(tracking.TrackableResource):
|
||||
def _initialize(self):
|
||||
raise RuntimeError()
|
||||
|
||||
def _list_functions_for_serialization(self):
|
||||
def _list_functions_for_serialization(self, unused_serialization_cache):
|
||||
# Overwrite this method to avoid the implementation of
|
||||
# base class to re-wrap the polymorphic functions into
|
||||
# another layer of `tf.function`.
|
||||
@ -426,6 +430,22 @@ def load(export_dir, tags=None):
|
||||
assert 6. == imported.f(x=tf.constant(2.)).numpy()
|
||||
```
|
||||
|
||||
_Loading Keras models_
|
||||
|
||||
Keras models are trackable, so they can be saved to SavedModel. The object
|
||||
returned by `tf.saved_model.load` is not a Keras object (i.e. doesn't have
|
||||
`.fit`, `.predict`, etc. methods). A few attributes and functions are still
|
||||
available: `.variables`, `.trainable_variables` and `.__call__`.
|
||||
|
||||
```python
|
||||
model = tf.keras.Model(...)
|
||||
tf.saved_model.save(model, path)
|
||||
imported = tf.saved_model.load(path)
|
||||
outputs = imported(inputs)
|
||||
```
|
||||
|
||||
Use `tf.keras.models.load_model` to restore the Keras model.
|
||||
|
||||
_Importing SavedModels from TensorFlow 1.x_
|
||||
|
||||
SavedModels from `tf.estimator.Estimator` or 1.x SavedModel APIs have a flat
|
||||
@ -462,6 +482,11 @@ def load(export_dir, tags=None):
|
||||
Raises:
|
||||
ValueError: If `tags` don't match a MetaGraph in the SavedModel.
|
||||
"""
|
||||
return load_internal(export_dir, tags)
|
||||
|
||||
|
||||
def load_internal(export_dir, tags=None, loader_cls=Loader):
|
||||
"""Loader implementation."""
|
||||
if tags is not None and not isinstance(tags, set):
|
||||
# Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
|
||||
# sequences for nest.flatten, so we put those through as-is.
|
||||
@ -479,9 +504,9 @@ def load(export_dir, tags=None):
|
||||
.format(export_dir, meta_graph_def.meta_info_def.tags, tags))
|
||||
object_graph_proto = meta_graph_def.object_graph_def
|
||||
with ops.init_scope():
|
||||
loader = _Loader(object_graph_proto,
|
||||
saved_model_proto,
|
||||
export_dir)
|
||||
loader = loader_cls(object_graph_proto,
|
||||
saved_model_proto,
|
||||
export_dir)
|
||||
root = loader.get(0)
|
||||
root.tensorflow_version = meta_graph_def.meta_info_def.tensorflow_version
|
||||
root.tensorflow_git_version = (
|
||||
|
@ -1711,5 +1711,26 @@ class SingleCycleTests(test.TestCase, parameterized.TestCase):
|
||||
f(x=constant_op.constant([[-1.]]))["output_0"].numpy())
|
||||
|
||||
|
||||
def test_object_with_extra_dependencies(self):
|
||||
|
||||
class Extra(tracking.AutoTrackable):
|
||||
|
||||
def _list_extra_dependencies_for_serialization(self, cache):
|
||||
if self not in cache:
|
||||
cache[self] = {"a": variables.Variable(5.)}
|
||||
return cache[self]
|
||||
root = Extra()
|
||||
path = tempfile.mkdtemp(prefix=self.get_temp_dir())
|
||||
save.save(root, path)
|
||||
imported = load.load(path)
|
||||
self.assertEqual(5, self.evaluate(imported.a))
|
||||
|
||||
root.a = variables.Variable(3.)
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
"object has an attribute named a, which is reserved."):
|
||||
save.save(root, path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -165,3 +165,7 @@ def deserialize(proto):
|
||||
if type_registration.should_load(proto):
|
||||
return (type_registration.from_proto(proto), type_registration.setter)
|
||||
return None
|
||||
|
||||
|
||||
def registered_identifiers():
|
||||
return _REVIVED_TYPE_REGISTRY.keys()
|
||||
|
@ -91,6 +91,10 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):
|
||||
# Object -> (name -> dep)
|
||||
self._extra_dependencies = object_identity.ObjectIdentityDictionary()
|
||||
self._functions = object_identity.ObjectIdentityDictionary()
|
||||
# Cache shared between objects in the same object graph. This is passed to
|
||||
# each trackable object's `_list_extra_dependencies_for_serialization` and
|
||||
# `_list_functions_for_serialization` function.
|
||||
self._serialization_cache = object_identity.ObjectIdentityDictionary()
|
||||
|
||||
def add_object(self, parent_node, name_in_parent, subgraph_root):
|
||||
"""Attach an object to `parent_node`, overriding any existing dependency."""
|
||||
@ -99,11 +103,23 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):
|
||||
|
||||
def list_dependencies(self, obj):
|
||||
"""Overrides a parent method to include `add_object` objects."""
|
||||
extra_dependencies = self._extra_dependencies.get(obj, {})
|
||||
extra_dependencies = self.list_extra_dependencies(obj)
|
||||
extra_dependencies.update(self._extra_dependencies.get(obj, {}))
|
||||
|
||||
used_names = set()
|
||||
for name, dep in super(_AugmentedGraphView, self).list_dependencies(obj):
|
||||
used_names.add(name)
|
||||
if name in extra_dependencies:
|
||||
# Extra dependencies (except for `.signatures`, which is always added
|
||||
# when saving) should not have naming conflicts with dependencies
|
||||
# defined by the user.
|
||||
if name != signature_serialization.SIGNATURE_ATTRIBUTE_NAME:
|
||||
raise ValueError(
|
||||
"Error when exporting object {} of with identifier={}. The object"
|
||||
" has an attribute named {}, which is reserved. List of all "
|
||||
"reserved attributes: {}".format(
|
||||
obj, obj._object_identifier, # pylint: disable=protected-access
|
||||
name, extra_dependencies.keys()))
|
||||
yield base.TrackableReference(name, extra_dependencies[name])
|
||||
else:
|
||||
yield base.TrackableReference(name, dep)
|
||||
@ -112,10 +128,15 @@ class _AugmentedGraphView(graph_view.ObjectGraphView):
|
||||
continue
|
||||
yield base.TrackableReference(name, dep)
|
||||
|
||||
def list_extra_dependencies(self, obj):
|
||||
return obj._list_extra_dependencies_for_serialization( # pylint: disable=protected-access
|
||||
self._serialization_cache)
|
||||
|
||||
def list_functions(self, obj):
|
||||
obj_functions = self._functions.get(obj, None)
|
||||
if obj_functions is None:
|
||||
obj_functions = obj._list_functions_for_serialization() # pylint: disable=protected-access
|
||||
obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access
|
||||
self._serialization_cache)
|
||||
self._functions[obj] = obj_functions
|
||||
return obj_functions
|
||||
|
||||
@ -614,10 +635,13 @@ def _write_object_proto(obj, proto, asset_file_def_index):
|
||||
registered_type_proto = revived_types.serialize(obj)
|
||||
if registered_type_proto is None:
|
||||
# Fallback for types with no matching registration
|
||||
# pylint:disable=protected-access
|
||||
registered_type_proto = saved_object_graph_pb2.SavedUserObject(
|
||||
identifier="_generic_user_object",
|
||||
identifier=obj._object_identifier,
|
||||
version=versions_pb2.VersionDef(
|
||||
producer=1, min_consumer=1, bad_consumers=[]))
|
||||
producer=1, min_consumer=1, bad_consumers=[]),
|
||||
metadata=obj._tracking_metadata)
|
||||
# pylint:enable=protected-access
|
||||
proto.user_object.CopyFrom(registered_type_proto)
|
||||
|
||||
|
||||
|
@ -204,7 +204,7 @@ class _SignatureMap(collections.Mapping, base.Trackable):
|
||||
def __repr__(self):
|
||||
return "_SignatureMap({})".format(self._signatures)
|
||||
|
||||
def _list_functions_for_serialization(self):
|
||||
def _list_functions_for_serialization(self, unused_serialization_cache):
|
||||
return {
|
||||
key: value for key, value in self.items()
|
||||
if isinstance(value, (def_function.Function, defun.ConcreteFunction))
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.python.ops import gen_io_ops as io_ops
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training.saving import saveable_object
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util import tf_decorator
|
||||
|
||||
# Key where the object graph proto is saved in a TensorBundle
|
||||
@ -463,6 +464,38 @@ def no_automatic_dependency_tracking(method):
|
||||
target=method, decorator_func=_method_wrapper)
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def no_automatic_dependency_tracking_scope(obj):
|
||||
"""A context that disables automatic dependency tracking when assigning attrs.
|
||||
|
||||
Objects that inherit from Autotrackable automatically creates dependencies
|
||||
to trackable objects through attribute assignments, and wraps data structures
|
||||
(lists or dicts) with trackable classes. This scope may be used to temporarily
|
||||
disable this behavior. This works similar to the decorator
|
||||
`no_automatic_dependency_tracking`.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
model = tf.keras.Model()
|
||||
model.arr1 = [] # Creates a ListWrapper object
|
||||
with no_automatic_dependency_tracking_scope(model):
|
||||
model.arr2 = [] # Creates a regular, untracked python list
|
||||
```
|
||||
|
||||
Args:
|
||||
obj: A trackable object.
|
||||
|
||||
Yields:
|
||||
a scope in which the object doesn't track dependencies.
|
||||
"""
|
||||
previous_value = getattr(obj, "_setattr_tracking", True)
|
||||
obj._setattr_tracking = False # pylint: disable=protected-access
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
obj._setattr_tracking = previous_value # pylint: disable=protected-access
|
||||
|
||||
|
||||
class Trackable(object):
|
||||
"""Base class for `Trackable` objects without automatic dependencies.
|
||||
|
||||
@ -548,6 +581,23 @@ class Trackable(object):
|
||||
# building.
|
||||
self._self_name_based_restores = set()
|
||||
|
||||
@property
|
||||
def _object_identifier(self):
|
||||
"""String used to identify this object in a SavedModel.
|
||||
|
||||
Generally, the object identifier is constant across objects of the same
|
||||
class, while the metadata field is used for instance-specific data.
|
||||
|
||||
Returns:
|
||||
String object identifier.
|
||||
"""
|
||||
return "_generic_user_object"
|
||||
|
||||
@property
|
||||
def _tracking_metadata(self):
|
||||
"""String containing object metadata, which is saved to the SavedModel."""
|
||||
return ""
|
||||
|
||||
def _no_dependency(self, value):
|
||||
"""If automatic dependency tracking is enabled, ignores `value`."""
|
||||
return value
|
||||
@ -881,15 +931,50 @@ class Trackable(object):
|
||||
"""
|
||||
return {}
|
||||
|
||||
def _list_functions_for_serialization(self):
|
||||
def _list_extra_dependencies_for_serialization(self, serialization_cache):
|
||||
"""Lists extra dependencies to serialize.
|
||||
|
||||
Internal sub-classes can override this method to return extra dependencies
|
||||
that should be saved with the object during SavedModel serialization. For
|
||||
example, this is used to save `trainable_variables` in Keras models. The
|
||||
python property `trainable_variables` contains logic to iterate through the
|
||||
weights from the model and its sublayers. The serialized Keras model saves
|
||||
`trainable_weights` as a trackable list of variables.
|
||||
|
||||
PLEASE NOTE when overriding this method:
|
||||
1. This function may only generate new trackable objects the first time it
|
||||
is called.
|
||||
2. The returned dictionary must not have naming conflicts with
|
||||
dependencies tracked by the root. In other words, if the root is
|
||||
tracking `object_1` with name 'x', and this functions returns
|
||||
`{'x': object_2}`, an error is raised when saving.
|
||||
|
||||
Args:
|
||||
serialization_cache: A dictionary shared between all objects in the same
|
||||
object graph. This object is passed to both
|
||||
`_list_extra_dependencies_for_serialization` and
|
||||
`_list_functions_for_serialization`.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping attribute names to trackable objects.
|
||||
"""
|
||||
del serialization_cache
|
||||
return dict()
|
||||
|
||||
def _list_functions_for_serialization(self, serialization_cache):
|
||||
"""Lists the functions of this trackable to serialize.
|
||||
|
||||
Internal sub-classes can override this with specific logic. E.g.
|
||||
`AutoTrackable` provides an implementation that returns the `attr`
|
||||
that return functions.
|
||||
|
||||
Args:
|
||||
serialization_cache: Dictionary passed to all objects in the same object
|
||||
graph during serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping attribute names to `Function` or
|
||||
`ConcreteFunction`.
|
||||
"""
|
||||
del serialization_cache
|
||||
return dict()
|
||||
|
@ -76,7 +76,7 @@ def _wrap_or_unwrap(value):
|
||||
elif type(value) == collections.OrderedDict:
|
||||
return _DictWrapper(value)
|
||||
elif type(value) == list:
|
||||
return _ListWrapper(value)
|
||||
return ListWrapper(value)
|
||||
else:
|
||||
return value
|
||||
# pylint: enable=unidiomatic-typecheck
|
||||
@ -371,9 +371,9 @@ class List(TrackableDataStructure, collections.Sequence):
|
||||
# TODO(tomhennigan) Update to collections.UserList?
|
||||
# TODO(allenl): Try switching this to wrapt.ObjectProxy again when we drop
|
||||
# Python 3.4 support (may still be tricky).
|
||||
class _ListWrapper(List, collections.MutableSequence,
|
||||
# Shadowed, but there for isinstance checks.
|
||||
list):
|
||||
class ListWrapper(List, collections.MutableSequence,
|
||||
# Shadowed, but there for isinstance checks.
|
||||
list):
|
||||
"""Wraps the built-in `list` to support restore-on-create for variables.
|
||||
|
||||
Unlike `List`, this sequence type is mutable in the same ways built-in lists
|
||||
@ -383,7 +383,7 @@ class _ListWrapper(List, collections.MutableSequence,
|
||||
refuses to save.
|
||||
|
||||
On assignment to an attribute of a Model or Trackable object, Python
|
||||
lists are replaced with _ListWrapper. Wrapping a list in a
|
||||
lists are replaced with ListWrapper. Wrapping a list in a
|
||||
`tf.contrib.checkpoint.NoDependency` object prevents this.
|
||||
"""
|
||||
|
||||
@ -393,26 +393,26 @@ class _ListWrapper(List, collections.MutableSequence,
|
||||
Args:
|
||||
wrapped_list: The initial value of the data structure. A shallow copy may
|
||||
be maintained for error checking. `wrapped_list` itself should not be
|
||||
modified directly after constructing the `_ListWrapper`, and if changes
|
||||
are detected the `_ListWrapper` will throw an exception on save.
|
||||
modified directly after constructing the `ListWrapper`, and if changes
|
||||
are detected the `ListWrapper` will throw an exception on save.
|
||||
"""
|
||||
# Monotonic flags which indicate this object would not be restored properly,
|
||||
# and therefore should throw an error on save to avoid giving the impression
|
||||
# that restoring it will work.
|
||||
self._non_append_mutation = False
|
||||
self._external_modification = False
|
||||
super(_ListWrapper, self).__init__(wrapped_list)
|
||||
super(ListWrapper, self).__init__(wrapped_list)
|
||||
self._last_wrapped_list_snapshot = list(self._storage)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def __copy__(self):
|
||||
copied = super(_ListWrapper, self).__copy__()
|
||||
copied = super(ListWrapper, self).__copy__()
|
||||
copied._non_append_mutation = self._non_append_mutation
|
||||
copied._external_modification = self._external_modification
|
||||
return copied
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
copied = super(_ListWrapper, self).__deepcopy__(memo)
|
||||
copied = super(ListWrapper, self).__deepcopy__(memo)
|
||||
copied._non_append_mutation = self._non_append_mutation
|
||||
copied._external_modification = self._external_modification
|
||||
return copied
|
||||
@ -463,7 +463,7 @@ class _ListWrapper(List, collections.MutableSequence,
|
||||
"wrap it in a tf.contrib.checkpoint.NoDependency object; it will be "
|
||||
"automatically un-wrapped and subsequently ignored." % (
|
||||
self, self._storage, self._last_wrapped_list_snapshot)))
|
||||
return super(_ListWrapper, self)._checkpoint_dependencies
|
||||
return super(ListWrapper, self)._checkpoint_dependencies
|
||||
|
||||
def __delitem__(self, key):
|
||||
self._non_append_mutation = True
|
||||
@ -502,13 +502,13 @@ class _ListWrapper(List, collections.MutableSequence,
|
||||
def append(self, value):
|
||||
"""Add a new trackable value."""
|
||||
self._check_external_modification()
|
||||
super(_ListWrapper, self).append(value)
|
||||
super(ListWrapper, self).append(value)
|
||||
self._update_snapshot()
|
||||
|
||||
def extend(self, values):
|
||||
"""Add a sequence of trackable values."""
|
||||
self._check_external_modification()
|
||||
super(_ListWrapper, self).extend(values)
|
||||
super(ListWrapper, self).extend(values)
|
||||
self._update_snapshot()
|
||||
|
||||
def __imul__(self, y):
|
||||
@ -518,7 +518,7 @@ class _ListWrapper(List, collections.MutableSequence,
|
||||
return self
|
||||
|
||||
# Relies on super() calling append, which updates the snapshot.
|
||||
return super(_ListWrapper, self).__imul__(y)
|
||||
return super(ListWrapper, self).__imul__(y)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self._storage == getattr(other, "_storage", other)
|
||||
@ -561,7 +561,7 @@ class _ListWrapper(List, collections.MutableSequence,
|
||||
def _track_value(self, value, name):
|
||||
"""Allows storage of non-trackable objects."""
|
||||
try:
|
||||
value = super(_ListWrapper, self)._track_value(value=value, name=name)
|
||||
value = super(ListWrapper, self)._track_value(value=value, name=name)
|
||||
except ValueError:
|
||||
# Even if this value isn't trackable, we need to make sure
|
||||
# NoDependency objects get unwrapped.
|
||||
@ -572,7 +572,7 @@ class _ListWrapper(List, collections.MutableSequence,
|
||||
def __repr__(self):
|
||||
return "ListWrapper(%s)" % (repr(self._storage),)
|
||||
|
||||
def _list_functions_for_serialization(self):
|
||||
def _list_functions_for_serialization(self, unused_functions):
|
||||
return {
|
||||
str(key): value for key, value in enumerate(self)
|
||||
if _is_function(value)
|
||||
@ -653,9 +653,9 @@ class Mapping(TrackableDataStructure, collections.Mapping):
|
||||
class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
|
||||
"""Wraps built-in dicts to support restore-on-create for variables.
|
||||
|
||||
_DictWrapper is to Mapping as _ListWrapper is to List. Unlike Mapping,
|
||||
_DictWrapper is to Mapping as ListWrapper is to List. Unlike Mapping,
|
||||
_DictWrapper allows non-string keys and values and arbitrary mutations (delete
|
||||
keys, reassign values). Like _ListWrapper, these mutations mean that
|
||||
keys, reassign values). Like ListWrapper, these mutations mean that
|
||||
_DictWrapper will raise an exception on save.
|
||||
"""
|
||||
|
||||
@ -827,7 +827,7 @@ class _DictWrapper(TrackableDataStructure, wrapt.ObjectProxy):
|
||||
for key, value in six.iteritems(dict(*args, **kwargs)):
|
||||
self[key] = value
|
||||
|
||||
def _list_functions_for_serialization(self):
|
||||
def _list_functions_for_serialization(self, unused_serialization_cache):
|
||||
return {
|
||||
key: value for key, value in self.items()
|
||||
if _is_function(value)
|
||||
@ -860,9 +860,9 @@ def _set_list_item(list_object, index_string, value):
|
||||
|
||||
revived_types.register_revived_type(
|
||||
"trackable_list_wrapper",
|
||||
lambda obj: isinstance(obj, _ListWrapper),
|
||||
lambda obj: isinstance(obj, ListWrapper),
|
||||
versions=[revived_types.VersionedTypeRegistration(
|
||||
object_factory=lambda proto: _ListWrapper([]),
|
||||
object_factory=lambda proto: ListWrapper([]),
|
||||
version=1,
|
||||
min_producer_version=1,
|
||||
min_consumer_version=1,
|
||||
|
@ -357,14 +357,14 @@ class ListWrapperTest(test.TestCase):
|
||||
# Skip methods that aren't overridden from object.
|
||||
continue
|
||||
|
||||
if list_method == getattr(data_structures._ListWrapper, name):
|
||||
if list_method == getattr(data_structures.ListWrapper, name):
|
||||
not_overridden.append(name)
|
||||
|
||||
if not_overridden:
|
||||
self.fail("_ListWrapper does not override %s" % (not_overridden))
|
||||
self.fail("ListWrapper does not override %s" % (not_overridden))
|
||||
|
||||
def testPickle(self):
|
||||
original = data_structures._ListWrapper([1, 2])
|
||||
original = data_structures.ListWrapper([1, 2])
|
||||
serialized = pickle.dumps(original)
|
||||
del original
|
||||
deserialized = pickle.loads(serialized)
|
||||
@ -372,7 +372,7 @@ class ListWrapperTest(test.TestCase):
|
||||
|
||||
def testSameStructure(self):
|
||||
l = [1]
|
||||
nest.assert_same_structure(l, data_structures._ListWrapper(copy.copy(l)))
|
||||
nest.assert_same_structure(l, data_structures.ListWrapper(copy.copy(l)))
|
||||
|
||||
def testFunctionCaching(self):
|
||||
@def_function.function
|
||||
@ -381,87 +381,87 @@ class ListWrapperTest(test.TestCase):
|
||||
|
||||
first_trace = f.get_concrete_function([constant_op.constant(2.)])
|
||||
second_trace = f.get_concrete_function(
|
||||
data_structures._ListWrapper([constant_op.constant(3.)]))
|
||||
data_structures.ListWrapper([constant_op.constant(3.)]))
|
||||
self.assertIs(first_trace, second_trace)
|
||||
|
||||
def testListWrapperBasic(self):
|
||||
# _ListWrapper, unlike List, compares like the built-in list type (since it
|
||||
# ListWrapper, unlike List, compares like the built-in list type (since it
|
||||
# is used to automatically replace lists).
|
||||
a = tracking.AutoTrackable()
|
||||
b = tracking.AutoTrackable()
|
||||
self.assertEqual([a, a],
|
||||
[a, a])
|
||||
self.assertEqual(data_structures._ListWrapper([a, a]),
|
||||
data_structures._ListWrapper([a, a]))
|
||||
self.assertEqual(data_structures.ListWrapper([a, a]),
|
||||
data_structures.ListWrapper([a, a]))
|
||||
self.assertEqual([a, a],
|
||||
data_structures._ListWrapper([a, a]))
|
||||
self.assertEqual(data_structures._ListWrapper([a, a]),
|
||||
data_structures.ListWrapper([a, a]))
|
||||
self.assertEqual(data_structures.ListWrapper([a, a]),
|
||||
[a, a])
|
||||
self.assertNotEqual([a, a],
|
||||
[b, a])
|
||||
self.assertNotEqual(data_structures._ListWrapper([a, a]),
|
||||
data_structures._ListWrapper([b, a]))
|
||||
self.assertNotEqual(data_structures.ListWrapper([a, a]),
|
||||
data_structures.ListWrapper([b, a]))
|
||||
self.assertNotEqual([a, a],
|
||||
data_structures._ListWrapper([b, a]))
|
||||
data_structures.ListWrapper([b, a]))
|
||||
self.assertLess([a], [a, b])
|
||||
self.assertLess(data_structures._ListWrapper([a]),
|
||||
data_structures._ListWrapper([a, b]))
|
||||
self.assertLess(data_structures.ListWrapper([a]),
|
||||
data_structures.ListWrapper([a, b]))
|
||||
self.assertLessEqual([a], [a, b])
|
||||
self.assertLessEqual(data_structures._ListWrapper([a]),
|
||||
data_structures._ListWrapper([a, b]))
|
||||
self.assertLessEqual(data_structures.ListWrapper([a]),
|
||||
data_structures.ListWrapper([a, b]))
|
||||
self.assertGreater([a, b], [a])
|
||||
self.assertGreater(data_structures._ListWrapper([a, b]),
|
||||
data_structures._ListWrapper([a]))
|
||||
self.assertGreater(data_structures.ListWrapper([a, b]),
|
||||
data_structures.ListWrapper([a]))
|
||||
self.assertGreaterEqual([a, b], [a])
|
||||
self.assertGreaterEqual(data_structures._ListWrapper([a, b]),
|
||||
data_structures._ListWrapper([a]))
|
||||
self.assertEqual([a], data_structures._ListWrapper([a]))
|
||||
self.assertGreaterEqual(data_structures.ListWrapper([a, b]),
|
||||
data_structures.ListWrapper([a]))
|
||||
self.assertEqual([a], data_structures.ListWrapper([a]))
|
||||
self.assertEqual([a], list(data_structures.List([a])))
|
||||
self.assertEqual([a, a], data_structures._ListWrapper([a]) + [a])
|
||||
self.assertEqual([a, a], [a] + data_structures._ListWrapper([a]))
|
||||
self.assertIsInstance(data_structures._ListWrapper([a]), list)
|
||||
self.assertEqual([a, a], data_structures.ListWrapper([a]) + [a])
|
||||
self.assertEqual([a, a], [a] + data_structures.ListWrapper([a]))
|
||||
self.assertIsInstance(data_structures.ListWrapper([a]), list)
|
||||
|
||||
def testAcceptsNonTrackableContent(self):
|
||||
l = data_structures._ListWrapper([1, 2, 3])
|
||||
l = data_structures.ListWrapper([1, 2, 3])
|
||||
self.assertEqual(l, [1, 2, 3])
|
||||
|
||||
def testWrapperChangesList(self):
|
||||
l = []
|
||||
l_wrapper = data_structures._ListWrapper(l)
|
||||
l_wrapper = data_structures.ListWrapper(l)
|
||||
l_wrapper.append(1)
|
||||
self.assertEqual([1], l)
|
||||
|
||||
def testListChangesWrapper(self):
|
||||
l = []
|
||||
l_wrapper = data_structures._ListWrapper(l)
|
||||
l_wrapper = data_structures.ListWrapper(l)
|
||||
l.append(1)
|
||||
self.assertEqual([1], l_wrapper)
|
||||
|
||||
def testLayerCollectionWithExternalMutation(self):
|
||||
l = []
|
||||
l_wrapper = data_structures._ListWrapper(l)
|
||||
l_wrapper = data_structures.ListWrapper(l)
|
||||
layer = core.Dense(1)
|
||||
l.append(layer)
|
||||
self.assertEqual([layer], l_wrapper.layers)
|
||||
|
||||
def testNotHashable(self):
|
||||
with self.assertRaises(TypeError):
|
||||
hash(data_structures._ListWrapper())
|
||||
hash(data_structures.ListWrapper())
|
||||
|
||||
def testDelItem(self):
|
||||
l = data_structures._ListWrapper([1, 2, 3, 4])
|
||||
l = data_structures.ListWrapper([1, 2, 3, 4])
|
||||
del l[0]
|
||||
self.assertEqual(l, [2, 3, 4])
|
||||
self.assertUnableToSave(l, "Unable to save .*__delitem__")
|
||||
|
||||
def testDelSlice(self):
|
||||
l = data_structures._ListWrapper([1, 2, 3, 4])
|
||||
l = data_structures.ListWrapper([1, 2, 3, 4])
|
||||
del l[2:3]
|
||||
self.assertEqual(l, [1, 2, 4])
|
||||
self.assertUnableToSave(l, "Unable to save .*__delslice__")
|
||||
|
||||
def testSetSlice_canSaveForNonTrackableItems(self):
|
||||
l = data_structures._ListWrapper([1, 2, 3, 4])
|
||||
l = data_structures.ListWrapper([1, 2, 3, 4])
|
||||
l[:] = 2, 8, 9, 0
|
||||
self.assertEqual(l, [2, 8, 9, 0])
|
||||
l._maybe_initialize_trackable() # pylint: disable=protected-access
|
||||
@ -470,30 +470,30 @@ class ListWrapperTest(test.TestCase):
|
||||
def testSetSlice_cannotSaveIfTrackableModified(self):
|
||||
v1 = resource_variable_ops.ResourceVariable(1.)
|
||||
v2 = resource_variable_ops.ResourceVariable(1.)
|
||||
l = data_structures._ListWrapper([1, 2, v1, v2])
|
||||
l = data_structures.ListWrapper([1, 2, v1, v2])
|
||||
l[:] = 2, 8, 9, v2
|
||||
self.assertEqual(l, [2, 8, 9, v2])
|
||||
self.assertUnableToSave(l, "Unable to save .*__setslice__")
|
||||
|
||||
def testSetSlice_truncate(self):
|
||||
l = data_structures._ListWrapper([1, 2, 3, 4])
|
||||
l = data_structures.ListWrapper([1, 2, 3, 4])
|
||||
l[:] = []
|
||||
self.assertEqual(l, [])
|
||||
|
||||
def testSetSlice_extend(self):
|
||||
l = data_structures._ListWrapper([1, 2, 3, 4])
|
||||
l = data_structures.ListWrapper([1, 2, 3, 4])
|
||||
l[2:] = 1, 2, 3, 4
|
||||
self.assertEqual(l, [1, 2, 1, 2, 3, 4])
|
||||
|
||||
def testIMulNegative(self):
|
||||
l = data_structures._ListWrapper([1, 2, 3, 4])
|
||||
l = data_structures.ListWrapper([1, 2, 3, 4])
|
||||
l *= -1
|
||||
self.assertEqual(l, [1, 2, 3, 4] * -1)
|
||||
self.assertUnableToSave(l, "Unable to save")
|
||||
|
||||
def testIMulPositive(self):
|
||||
v = variables.Variable(1.)
|
||||
l = data_structures._ListWrapper([1, 2, 3, 4, v])
|
||||
l = data_structures.ListWrapper([1, 2, 3, 4, v])
|
||||
self.assertEqual([("4", v)], l._checkpoint_dependencies)
|
||||
root = util.Checkpoint(l=l)
|
||||
prefix = os.path.join(self.get_temp_dir(), "ckpt")
|
||||
@ -506,7 +506,7 @@ class ListWrapperTest(test.TestCase):
|
||||
self.assertAllClose(1., v.numpy())
|
||||
|
||||
def testSort(self):
|
||||
l = data_structures._ListWrapper([1, 2, 3, 4])
|
||||
l = data_structures.ListWrapper([1, 2, 3, 4])
|
||||
l.sort()
|
||||
self.assertEqual(l, [1, 2, 3, 4])
|
||||
# Regardless of being a no-op for the input list, we still refuse to save.
|
||||
@ -708,7 +708,7 @@ class MappingTests(test.TestCase):
|
||||
# methods/properties on the object. So the options are either not to
|
||||
# subclass dict (in which case update will call normal iter methods, but the
|
||||
# object won't pass isinstance checks) or to subclass dict and keep that
|
||||
# storage updated (no shadowing all its methods like _ListWrapper).
|
||||
# storage updated (no shadowing all its methods like ListWrapper).
|
||||
new_dict.update(model.d)
|
||||
self.assertEqual({1: 3}, new_dict)
|
||||
|
||||
@ -831,14 +831,14 @@ class MappingTests(test.TestCase):
|
||||
|
||||
def testListAddOrder(self):
|
||||
self.assertEqual([1., 2.],
|
||||
data_structures._ListWrapper([1.])
|
||||
+ data_structures._ListWrapper([2.]))
|
||||
data_structures.ListWrapper([1.])
|
||||
+ data_structures.ListWrapper([2.]))
|
||||
self.assertEqual([1., 2.],
|
||||
data_structures._ListWrapper([1.])
|
||||
data_structures.ListWrapper([1.])
|
||||
+ [2.])
|
||||
self.assertEqual([1., 2.],
|
||||
[1.]
|
||||
+ data_structures._ListWrapper([2.]))
|
||||
+ data_structures.ListWrapper([2.]))
|
||||
|
||||
def testSameStructure(self):
|
||||
d = {1: "a"}
|
||||
|
@ -41,6 +41,7 @@ def has_weights(obj):
|
||||
|
||||
def filter_empty_layer_containers(layer_list):
|
||||
"""Filter out empty Layer-like containers and uniquify."""
|
||||
# TODO(b/130381733): Make this an attribute in base_layer.Layer.
|
||||
existing = object_identity.ObjectIdentitySet()
|
||||
to_visit = layer_list[::-1]
|
||||
filtered = []
|
||||
|
@ -94,7 +94,7 @@ class AutoTrackable(base.Trackable):
|
||||
"""Override to allow TrackableBase to disable dependency tracking."""
|
||||
return data_structures.NoDependency(value)
|
||||
|
||||
def _list_functions_for_serialization(self):
|
||||
def _list_functions_for_serialization(self, unused_serialization_cache):
|
||||
"""Return a dict of `Function`s of a trackable."""
|
||||
functions = {}
|
||||
for attribute_name in dir(self):
|
||||
@ -192,7 +192,7 @@ class CapturableResource(base.Trackable):
|
||||
self._resource_handle = self._create_resource()
|
||||
return self._resource_handle
|
||||
|
||||
def _list_functions_for_serialization(self):
|
||||
def _list_functions_for_serialization(self, unused_functions):
|
||||
@def_function.function(input_signature=[], autograph=False)
|
||||
def _creator():
|
||||
resource = self._create_resource()
|
||||
|
@ -6,4 +6,12 @@ tf_class {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_config"
|
||||
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
@ -6,4 +6,12 @@ tf_class {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_config"
|
||||
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
@ -6,4 +6,12 @@ tf_class {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ndim\', \'max_ndim\', \'min_ndim\', \'axes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_config"
|
||||
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_config"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user