Move Keras SavedModel code into separate files and directory.

No changes have been made to the implementation.

PiperOrigin-RevId: 254087120

parent 5a5fd3518f
commit cf9bdb260f
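For orientation (an editor's sketch, not part of the original commit message): the public Keras entry points are unchanged; only the modules behind them move. Roughly, in terms of the import paths touched by this diff:

```python
# Before: the experimental SavedModel helpers and the new-style serialization
# code both lived in tensorflow/python/keras/saving/saved_model.py.
from tensorflow.python.keras.saving import saved_model

# After: the experimental helpers are renamed, and the new-style serialization
# is split into a dedicated saved_model/ package.
from tensorflow.python.keras.saving import saved_model_experimental   # export_saved_model, load_from_saved_model
from tensorflow.python.keras.saving.saved_model import save as saved_model_save   # save()
from tensorflow.python.keras.saving.saved_model import load as saved_model_load   # load()
from tensorflow.python.keras.saving.saved_model import constants, serialized_attributes, utils
```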
@@ -18,9 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.python.keras import saving
+from tensorflow.python.keras.saving import saved_model_experimental


 # TODO(kathywu): Remove all contrib callers, switch to tf.keras.
-save_keras_model = saving.export_saved_model
-load_keras_model = saving.load_from_saved_model
+save_keras_model = saved_model_experimental.export_saved_model
+load_keras_model = saved_model_experimental.load_from_saved_model
@@ -21,7 +21,7 @@ from __future__ import print_function
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import saved_model_test_base as test_base
 from tensorflow.python.eager import test
-from tensorflow.python.keras.saving import saved_model
+from tensorflow.python.keras.saving import saved_model_experimental as saved_model


 class KerasExperimentalSaveLoadTest(test_base.TestSavedModelBase):
@@ -27,7 +27,7 @@ from __future__ import print_function
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import saved_model_test_base as test_base
 from tensorflow.python.eager import test
-from tensorflow.python.keras.saving import saved_model as keras_saved_model
+from tensorflow.python.keras.saving import saved_model_experimental as keras_saved_model

 _DEFAULT_FUNCTION_KEY = 'serving_default'

@@ -275,7 +275,12 @@ py_library(
         "saving/hdf5_format.py",
         "saving/model_config.py",
         "saving/save.py",
-        "saving/saved_model.py",
+        "saving/saved_model/constants.py",
+        "saving/saved_model/load.py",
+        "saving/saved_model/save.py",
+        "saving/saved_model/serialized_attributes.py",
+        "saving/saved_model/utils.py",
+        "saving/saved_model_experimental.py",
         "saving/saving_utils.py",
     ],
     srcs_version = "PY2AND3",
@@ -1620,9 +1625,9 @@ tf_py_test(
 )

 tf_py_test(
-    name = "saved_model_test",
+    name = "saved_model_experimental_test",
     size = "medium",
-    srcs = ["saving/saved_model_test.py"],
+    srcs = ["saving/saved_model_experimental_test.py"],
     additional_deps = [
         ":keras",
         "@absl_py//absl/testing:parameterized",
@@ -1636,6 +1641,22 @@ tf_py_test(
     ],
 )

+tf_py_test(
+    name = "saved_model_test",
+    size = "medium",
+    srcs = ["saving/saved_model/saved_model_test.py"],
+    additional_deps = [
+        ":keras",
+        "@absl_py//absl/testing:parameterized",
+        "//third_party/py/numpy",
+        "//tensorflow/python:client_testlib",
+    ],
+    shard_count = 4,
+    tags = [
+        "no_windows",
+    ],
+)
+
 tf_py_test(
     name = "saving_utils_test",
     size = "medium",
@@ -50,7 +50,7 @@ from tensorflow.python.keras.engine import input_spec
 from tensorflow.python.keras.engine import node as node_module
 from tensorflow.python.keras.mixed_precision.experimental import autocast_variable
 from tensorflow.python.keras.mixed_precision.experimental import policy
-from tensorflow.python.keras.saving import saved_model
+from tensorflow.python.keras.saving.saved_model import save as saved_model
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import tf_utils
 # A module that only depends on `keras.layers` import these from here.
@@ -323,8 +323,8 @@ class ActivationV2IntegrationTest(keras_parameterized.TestCase):
         verbose=2)

     output_path = os.path.join(self.get_temp_dir(), 'tf_keras_saved_model')
-    keras.saving.saved_model.export_saved_model(model, output_path)
-    loaded_model = keras.saving.saved_model.load_from_saved_model(output_path)
+    model.save(output_path, save_format='tf')
+    loaded_model = keras.models.load_model(output_path)
     self.assertEqual(model.summary(), loaded_model.summary())

 if __name__ == '__main__':
@@ -30,6 +30,6 @@ from tensorflow.python.keras.saving.model_config import model_from_json
 from tensorflow.python.keras.saving.model_config import model_from_yaml
 from tensorflow.python.keras.saving.save import load_model
 from tensorflow.python.keras.saving.save import save_model
-from tensorflow.python.keras.saving.saved_model import export_saved_model
-from tensorflow.python.keras.saving.saved_model import load_from_saved_model
+from tensorflow.python.keras.saving.saved_model_experimental import export_saved_model
+from tensorflow.python.keras.saving.saved_model_experimental import load_from_saved_model
 from tensorflow.python.keras.saving.saving_utils import trace_model_call
@@ -24,7 +24,8 @@ import six

 from tensorflow.python import tf2
 from tensorflow.python.keras.saving import hdf5_format
-from tensorflow.python.keras.saving import saved_model
+from tensorflow.python.keras.saving.saved_model import load as saved_model_load
+from tensorflow.python.keras.saving.saved_model import save as saved_model_save
 from tensorflow.python.saved_model import loader_impl
 from tensorflow.python.util.tf_export import keras_export

@@ -103,7 +104,7 @@ def save_model(model,
     hdf5_format.save_model_to_hdf5(
         model, filepath, overwrite, include_optimizer)
   else:
-    saved_model.save(model, filepath, overwrite, include_optimizer)
+    saved_model_save.save(model, filepath, overwrite, include_optimizer)


 @keras_export('keras.models.load_model')
@@ -138,7 +139,7 @@ def load_model(filepath, custom_objects=None, compile=True):  # pylint: disable=

   if isinstance(filepath, six.string_types):
     loader_impl.parse_saved_model(filepath)
-    return saved_model.load_from_saved_model_v2(filepath, compile)
+    return saved_model_load.load(filepath, compile)

   raise IOError(
       'Unable to load model. Filepath is not an hdf5 file (or h5py is not '
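A hedged usage sketch (editor's addition) of the dispatch above: a non-HDF5 target now routes through `saved_model_save.save` / `saved_model_load.load`, which is what `model.save(..., save_format='tf')` and `keras.models.load_model` end up calling. The paths and toy model below are illustrative only.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
model.compile('rmsprop', 'mse')
model.fit(np.random.random((8, 3)), np.random.random((8, 4)), verbose=0)

# HDF5 path: dispatched to hdf5_format.save_model_to_hdf5 (requires h5py).
model.save('/tmp/model.h5', save_format='h5')

# SavedModel path: dispatched to saved_model_save.save / saved_model_load.load.
model.save('/tmp/model_tf', save_format='tf')
restored = tf.keras.models.load_model('/tmp/model_tf')
np.testing.assert_allclose(model.predict(np.ones((1, 3))),
                           restored.predict(np.ones((1, 3))), rtol=1e-5)
```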
(Diff of one file suppressed because it is too large.)

tensorflow/python/keras/saving/saved_model/constants.py (new file, 24 lines)
@@ -0,0 +1,24 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constants for Keras SavedModel serialization."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Namespace used to store all attributes added during serialization.
# e.g. the list of layers can be accessed using `loaded.keras_api.layers`, in an
# object loaded from `tf.saved_model.load()`.
KERAS_ATTR = 'keras_api'
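A minimal sketch (editor's addition) of how this namespace shows up to consumers, based on the comment above and on saved_model_test.py added later in this commit; the exact set of attributes under `keras_api` is defined by the SerializedAttributes/CommonEndpoints classes.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
model.predict(np.ones((1, 3)))  # Make sure inputs are defined before saving.
tf.saved_model.save(model, '/tmp/keras_sm')

loaded = tf.saved_model.load('/tmp/keras_sm')
# Keras-specific attributes added during serialization live under `keras_api`.
print(loaded.keras_api.layers)               # Serialized sublayers.
print(loaded.keras_api.trainable_variables)  # Same variables as the original model.
```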
tensorflow/python/keras/saving/saved_model/load.py (new file, 334 lines)

@@ -0,0 +1,334 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Keras SavedModel deserialization."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
|
||||
from tensorflow.python.eager import function as defun
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import regularizers
|
||||
from tensorflow.python.keras.engine import input_spec
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.saving.saved_model import constants
|
||||
from tensorflow.python.keras.saving.saved_model import utils
|
||||
from tensorflow.python.keras.saving.saved_model.serialized_attributes import CommonEndpoints
|
||||
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
|
||||
from tensorflow.python.saved_model import load as tf_load
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.training.tracking.tracking import delete_tracking
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
|
||||
# To avoid circular dependencies between keras/engine and keras/saving,
|
||||
# code in keras/saving must delay imports.
|
||||
|
||||
# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
|
||||
# once the issue with copybara is fixed.
|
||||
# pylint:disable=g-inconsistent-quotes
|
||||
models_lib = LazyLoader("models_lib", globals(),
|
||||
"tensorflow.python.keras.models")
|
||||
base_layer = LazyLoader(
|
||||
"base_layer", globals(),
|
||||
"tensorflow.python.keras.engine.base_layer")
|
||||
network_lib = LazyLoader(
|
||||
"network_lib", globals(),
|
||||
"tensorflow.python.keras.engine.network")
|
||||
training_lib = LazyLoader(
|
||||
"training_lib", globals(),
|
||||
"tensorflow.python.keras.engine.training")
|
||||
# pylint:enable=g-inconsistent-quotes
|
||||
|
||||
|
||||
PUBLIC_ATTRIBUTES = CommonEndpoints.all_functions.union(
|
||||
CommonEndpoints.all_checkpointable_objects)
|
||||
PUBLIC_ATTRIBUTES.add(constants.KERAS_ATTR)
|
||||
|
||||
|
||||
def load(path, compile=True): # pylint: disable=redefined-builtin
|
||||
"""Loads Keras objects from a SavedModel.
|
||||
|
||||
Any Keras layer or model saved to the SavedModel will be loaded back
|
||||
as Keras objects. Other objects are loaded as regular trackable objects (same
|
||||
as `tf.saved_model.load`).
|
||||
|
||||
Currently, Keras saving/loading only retains the Keras object's weights,
|
||||
losses, and call function.
|
||||
|
||||
The loaded model can be re-compiled, but the original optimizer, compiled loss
|
||||
functions, and metrics are not retained. This is temporary, and `model.save`
|
||||
will soon be able to serialize compiled models.
|
||||
|
||||
Args:
|
||||
path: Path to SavedModel.
|
||||
compile: If true, compile the model after loading it.
|
||||
|
||||
Returns:
|
||||
Object loaded from SavedModel.
|
||||
"""
|
||||
# TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
|
||||
# TODO(kathywu): Add code to load from objects that contain all endpoints
|
||||
model = tf_load.load_internal(path, loader_cls=KerasObjectLoader)
|
||||
|
||||
if isinstance(model, RevivedModel) and compile:
|
||||
# TODO(kathywu): Use compiled objects from SavedModel, instead of
|
||||
# creating new objects from the training config.
|
||||
if model._training_config is not None: # pylint: disable=protected-access
|
||||
model.compile(**saving_utils.compile_args_from_training_config(
|
||||
model._training_config)) # pylint: disable=protected-access
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class KerasObjectLoader(tf_load.Loader):
|
||||
"""Loader that recreates Keras objects."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(KerasObjectLoader, self).__init__(*args, **kwargs)
|
||||
self._finalize()
|
||||
|
||||
def _finalize(self):
|
||||
# pylint: disable=protected-access
|
||||
for node in self._nodes:
|
||||
if isinstance(node, RevivedModel):
|
||||
call_fn = node.keras_api.call_and_return_conditional_losses
|
||||
if call_fn.input_signature is None:
|
||||
inputs = infer_inputs_from_restored_call_function(call_fn)
|
||||
else:
|
||||
inputs = call_fn.input_signature[0]
|
||||
if isinstance(node, RevivedSequential):
|
||||
with trackable.no_automatic_dependency_tracking_scope(node):
|
||||
node._layers = []
|
||||
for layer in node.keras_api.layers:
|
||||
node.add(layer)
|
||||
|
||||
if not node.inputs:
|
||||
# Since this revived object is technically a subclassed model (even if
|
||||
# the original model is functional/sequential), inputs should be set.
|
||||
node._set_inputs(inputs)
|
||||
if isinstance(node, RevivedLayer):
|
||||
if hasattr(node.keras_api, 'layer_regularization_losses'):
|
||||
losses = getattr(node.keras_api, 'layer_regularization_losses', [])
|
||||
else:
|
||||
# Some earlier SavedModels may not have layer_regularization_losses
|
||||
# serialized separately. Fall back to using the regularization_losses
|
||||
# list if it does not exist.
|
||||
losses = node._serialized_attributes.get('regularization_losses', [])
|
||||
for loss in losses:
|
||||
node.add_loss(loss)
|
||||
|
||||
# Use wrapped activity regularizer function if the layer's activity
|
||||
# regularizer wasn't created during initialization.
|
||||
if node.activity_regularizer is None:
|
||||
node.activity_regularizer = getattr(node.keras_api,
|
||||
'activity_regularizer_fn', None)
|
||||
|
||||
# Now that the node object has been fully loaded and restored from the
|
||||
# checkpoint, the object no longer needs to track objects added from
|
||||
# SerializedAttributes. (Note that saving a training checkpoint still
|
||||
# functions correctly, because layers and variables are tracked
|
||||
# separately by the Layer object.)
|
||||
# TODO(kathywu): Instead of outright deleting these nodes (which would
|
||||
# make restoring from a different checkpoint tricky), mark them as extra
|
||||
# dependencies that are OK to overwrite.
|
||||
for name in PUBLIC_ATTRIBUTES:
|
||||
delete_tracking(node, name)
|
||||
|
||||
# pylint: enable=protected-access
|
||||
|
||||
def _recreate_base_user_object(self, proto):
|
||||
revived_classes = {
|
||||
'_tf_keras_layer': (RevivedLayer, base_layer.Layer),
|
||||
'_tf_keras_network': (RevivedNetwork, network_lib.Network),
|
||||
'_tf_keras_model': (RevivedModel, training_lib.Model),
|
||||
'_tf_keras_sequential': (RevivedSequential, models_lib.Sequential)
|
||||
}
|
||||
|
||||
parent_classes = revived_classes.get(proto.identifier, None)
|
||||
|
||||
if parent_classes is not None:
|
||||
parent_classes = revived_classes[proto.identifier]
|
||||
metadata = json.loads(proto.metadata)
|
||||
revived_cls = type(
|
||||
compat.as_str(metadata['class_name']),
|
||||
parent_classes,
|
||||
{'__setattr__': parent_classes[1].__setattr__})
|
||||
obj = revived_cls._init_from_metadata(metadata) # pylint: disable=protected-access
|
||||
return obj, revived_cls._revive_setter # pylint: disable=protected-access
|
||||
|
||||
return super(KerasObjectLoader, self)._recreate_base_user_object(proto)
|
||||
|
||||
|
||||
# TODO(kathywu): Centrally define keys and functions for both serialization and
|
||||
# deserialization.
|
||||
class RevivedLayer(object):
|
||||
"""Keras layer loaded from a SavedModel."""
|
||||
|
||||
@classmethod
|
||||
def _init_from_metadata(cls, metadata):
|
||||
"""Create revived layer from metadata stored in the SavedModel proto."""
|
||||
init_args = dict(
|
||||
name=metadata['name'],
|
||||
trainable=metadata['trainable'])
|
||||
if metadata.get('dtype') is not None:
|
||||
init_args['dtype'] = metadata['dtype']
|
||||
if metadata.get('batch_input_shape') is not None:
|
||||
init_args['batch_input_shape'] = metadata['batch_input_shape']
|
||||
|
||||
revived_obj = cls(**init_args)
|
||||
|
||||
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
|
||||
# pylint:disable=protected-access
|
||||
revived_obj._expects_training_arg = metadata['expects_training_arg']
|
||||
if metadata.get('config') is not None:
|
||||
revived_obj._config = metadata['config']
|
||||
if metadata.get('input_spec') is not None:
|
||||
revived_obj.input_spec = recursively_deserialize_keras_object(
|
||||
metadata['input_spec'],
|
||||
module_objects={'InputSpec': input_spec.InputSpec})
|
||||
if metadata.get('activity_regularizer') is not None:
|
||||
revived_obj.activity_regularizer = regularizers.deserialize(
|
||||
metadata['activity_regularizer'])
|
||||
|
||||
# Store attributes revived from SerializedAttributes in an un-tracked
|
||||
# dictionary. The attributes are the ones listed in CommonEndpoints or
|
||||
# "keras_api" for keras-specific attributes.
|
||||
revived_obj._serialized_attributes = {}
|
||||
# pylint:enable=protected-access
|
||||
|
||||
return revived_obj
|
||||
|
||||
def _revive_setter(self, name, value):
|
||||
"""Reattaches attributes from the SavedModel to the newly revived object."""
|
||||
if name in PUBLIC_ATTRIBUTES:
|
||||
if isinstance(value, trackable.Trackable):
|
||||
self._track_trackable(value, name=name)
|
||||
self._serialized_attributes[name] = value
|
||||
else:
|
||||
setattr(self, name, value)
|
||||
|
||||
@property
|
||||
def keras_api(self):
|
||||
return self._serialized_attributes[constants.KERAS_ATTR]
|
||||
|
||||
def get_config(self):
|
||||
if hasattr(self, '_config'):
|
||||
return self._config
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, inputs, *args, **kwargs):
|
||||
"""Calls the revived layer and add conditional losses."""
|
||||
call_fn = utils.use_wrapped_call(
|
||||
self, self.keras_api.call_and_return_conditional_losses)
|
||||
return call_fn(inputs, *args, **kwargs)
|
||||
|
||||
|
||||
def recursively_deserialize_keras_object(config, module_objects=None):
|
||||
"""Deserialize Keras object from a nested structure."""
|
||||
if isinstance(config, dict):
|
||||
if 'class_name' in config:
|
||||
return deserialize_keras_object(config, module_objects=module_objects)
|
||||
else:
|
||||
return {key: recursively_deserialize_keras_object(config[key],
|
||||
module_objects)
|
||||
for key in config}
|
||||
if isinstance(config, (tuple, list)):
|
||||
return [recursively_deserialize_keras_object(x, module_objects)
|
||||
for x in config]
|
||||
else:
|
||||
raise ValueError('Unable to decode config: {}'.format(config))
|
||||
|
||||
|
||||
def infer_inputs_from_restored_call_function(fn):
|
||||
"""Returns TensorSpec of inputs from a restored call function.
|
||||
|
||||
Args:
|
||||
fn: Restored layer call function. It is assumed that the inputs are entirely
|
||||
in the first argument.
|
||||
|
||||
Returns:
|
||||
TensorSpec of call function inputs.
|
||||
"""
|
||||
def common_spec(x, y):
|
||||
return tensor_spec.TensorSpec(defun.common_shape(x.shape, y.shape),
|
||||
x.dtype, x.name)
|
||||
spec = fn.concrete_functions[0].structured_input_signature[0][0]
|
||||
for concrete in fn.concrete_functions[1:]:
|
||||
spec2 = concrete.structured_input_signature[0][0]
|
||||
spec = nest.map_structure(common_spec, spec, spec2)
|
||||
return spec
|
||||
|
||||
|
||||
class RevivedNetwork(RevivedLayer):
|
||||
"""Keras network of layers loaded from a SavedModel."""
|
||||
|
||||
@classmethod
|
||||
def _init_from_metadata(cls, metadata):
|
||||
"""Create revived network from metadata stored in the SavedModel proto."""
|
||||
# TODO(kathywu): Refactor logic here so that RevivedNetwork uses the
|
||||
revived_obj = cls(name=metadata['name'])
|
||||
|
||||
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
|
||||
# pylint:disable=protected-access
|
||||
if metadata.get('dtype') is not None:
|
||||
revived_obj._dtype = metadata['dtype']
|
||||
revived_obj.trainable = metadata['trainable']
|
||||
|
||||
revived_obj._expects_training_arg = metadata['expects_training_arg']
|
||||
if metadata.get('config') is not None:
|
||||
revived_obj._config = metadata['config']
|
||||
|
||||
if metadata.get('activity_regularizer') is not None:
|
||||
revived_obj.activity_regularizer = regularizers.deserialize(
|
||||
metadata['activity_regularizer'])
|
||||
|
||||
# Store attributes revived from SerializedAttributes in an un-tracked
|
||||
# dictionary. The attributes are the ones listed in CommonEndpoints or
|
||||
# "keras_api" for keras-specific attributes.
|
||||
revived_obj._serialized_attributes = {}
|
||||
# pylint:enable=protected-access
|
||||
|
||||
return revived_obj
|
||||
|
||||
|
||||
class RevivedModel(RevivedNetwork):
|
||||
"""Keras model loaded from a SavedModel."""
|
||||
|
||||
@classmethod
|
||||
def _init_from_metadata(cls, metadata):
|
||||
"""Create revived model from metadata stored in the SavedModel proto."""
|
||||
revived_obj = super(RevivedModel, cls)._init_from_metadata(metadata)
|
||||
|
||||
with trackable.no_automatic_dependency_tracking_scope(revived_obj):
|
||||
revived_obj._training_config = metadata.get('training_config') # pylint:disable=protected-access
|
||||
|
||||
return revived_obj
|
||||
|
||||
|
||||
class RevivedSequential(RevivedModel):
|
||||
"""Keras sequential model loaded from a SavedModel."""
|
||||
|
||||
@classmethod
|
||||
def _init_from_metadata(cls, metadata):
|
||||
"""Create revived Sequential model from SavedModel metadata."""
|
||||
revived_obj = super(RevivedSequential, cls)._init_from_metadata(metadata)
|
||||
return revived_obj
|
||||
|
||||
def call(self, *args, **kwargs):
|
||||
return models_lib.Sequential.call(self, *args, **kwargs)
|
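A hedged usage sketch (editor's addition) of `load()` above, based on its docstring and on `test_compiled_model` in the test file added by this commit; the internal import path is how the tests reach this module and may not be a stable public API.

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.saving.saved_model import load as keras_load

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
model.compile('rmsprop', 'mse')
model.fit(np.random.random((8, 3)), np.random.random((8, 4)), verbose=0)
model.save('/tmp/revived', save_format='tf')

# Unlike tf.saved_model.load, this returns revived Keras objects (a RevivedModel
# here), and since the SavedModel carries a training config it is re-compiled.
revived = keras_load.load('/tmp/revived', compile=True)
revived.fit(np.random.random((8, 3)), np.random.random((8, 4)), verbose=0)
```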
tensorflow/python/keras/saving/saved_model/save.py (new file, 573 lines)

@@ -0,0 +1,573 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Keras SavedModel serialization."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import os
|
||||
import weakref
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.engine import base_layer_utils
|
||||
from tensorflow.python.keras.engine import input_spec
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.saving.saved_model import load as keras_load
|
||||
from tensorflow.python.keras.saving.saved_model import serialized_attributes
|
||||
from tensorflow.python.keras.saving.saved_model import utils
|
||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import save as save_lib
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.training.tracking import data_structures
|
||||
from tensorflow.python.training.tracking import layer_utils as trackable_layer_utils
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
|
||||
# To avoid circular dependencies between keras/engine and keras/saving,
|
||||
# code in keras/saving must delay imports.
|
||||
|
||||
# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
|
||||
# once the issue with copybara is fixed.
|
||||
# pylint:disable=g-inconsistent-quotes
|
||||
base_layer = LazyLoader(
|
||||
"base_layer", globals(),
|
||||
"tensorflow.python.keras.engine.base_layer")
|
||||
training_lib = LazyLoader(
|
||||
"training_lib", globals(),
|
||||
"tensorflow.python.keras.engine.training")
|
||||
# pylint:enable=g-inconsistent-quotes
|
||||
|
||||
|
||||
def save(model, filepath, overwrite, include_optimizer):
|
||||
"""Saves a model as a SavedModel to the filepath.
|
||||
|
||||
Args:
|
||||
model: Keras model instance to be saved.
|
||||
filepath: String path to save the model.
|
||||
overwrite: whether to overwrite the existing filepath.
|
||||
include_optimizer: If True, save the model's optimizer state.
|
||||
|
||||
Raises:
|
||||
ValueError: if the model's inputs have not been defined.
|
||||
"""
|
||||
# If file exists and should not be overwritten.
|
||||
if not overwrite and os.path.exists(filepath):
|
||||
proceed = ask_to_proceed_with_overwrite(filepath)
|
||||
if not proceed:
|
||||
return
|
||||
|
||||
if _should_skip_serialization(model):
|
||||
saving_utils.raise_model_input_error(model)
|
||||
|
||||
if not include_optimizer:
|
||||
orig_optimizer = model.optimizer
|
||||
model.optimizer = None
|
||||
|
||||
save_lib.save(model, filepath)
|
||||
|
||||
if not include_optimizer:
|
||||
model.optimizer = orig_optimizer
|
||||
|
||||
|
||||
# Keys for the serialization cache.
|
||||
# Maps to the keras serialization dict {Layer --> SerializedAttributes object}
|
||||
_KERAS_CACHE_KEY = 'keras_serialized_attributes'
|
||||
|
||||
|
||||
def serialize_all_attributes(layer, serialization_cache):
|
||||
"""Serialize all attributes in the layer."""
|
||||
save_model_default_signature = False
|
||||
if _KERAS_CACHE_KEY not in serialization_cache:
|
||||
keras_cache = serialization_cache[_KERAS_CACHE_KEY] = {}
|
||||
if isinstance(layer, training_lib.Model):
|
||||
# Only trace default signature if the root object is a Model. Since the
|
||||
# keras cache key is only created in this method, we know that the object
|
||||
# is root if the key does not yet exist in the cache.
|
||||
save_model_default_signature = True
|
||||
else:
|
||||
keras_cache = serialization_cache[_KERAS_CACHE_KEY]
|
||||
|
||||
if layer in keras_cache:
|
||||
return keras_cache[layer]
|
||||
serialized_attr = keras_cache[layer] = (
|
||||
serialized_attributes.SerializedAttributes.new(layer))
|
||||
|
||||
if _should_skip_serialization(layer):
|
||||
return serialized_attr
|
||||
|
||||
function_dict = {}
|
||||
if save_model_default_signature:
|
||||
# For compatibility with the tf.Lite Converter, the default save signature
|
||||
# should be traced without nested calls to other wrapped functions.
|
||||
# TODO(kathywu): Investigate why having nested calls results in a stateful
|
||||
# function. Perhaps something to do with losses, which are traced in nested
|
||||
# calls but not in the flat call.
|
||||
function_dict['_default_save_signature'] = _default_save_signature(layer)
|
||||
else:
|
||||
function_dict['_default_save_signature'] = None
|
||||
|
||||
object_dict = _wrap_layer_objects(layer, serialization_cache)
|
||||
try:
|
||||
function_dict.update(_wrap_layer_functions(layer, serialization_cache))
|
||||
except (ValueError, TypeError) as e:
|
||||
logging.warning('Skipping full serialization of object {}, because an '
|
||||
'error occurred while tracing layer functions. Error '
|
||||
'message: {}'.format(layer, e))
|
||||
else:
|
||||
# Add checkpointable objects and functions to the SerializedAttributes object
|
||||
# only if all functions are successfully traced.
|
||||
# The `set_and_validate_*` function ensures that all required attributes are
|
||||
# exported with the correct type.
|
||||
serialized_attr.set_and_validate_objects(object_dict)
|
||||
serialized_attr.set_and_validate_functions(function_dict)
|
||||
return serialized_attr
|
||||
|
||||
|
||||
def _should_skip_serialization(layer):
|
||||
"""Skip serializing extra objects and functions if layer inputs aren't set."""
|
||||
if isinstance(layer, training_lib.Model):
|
||||
try:
|
||||
# pylint:disable=pointless-statement
|
||||
layer.inputs
|
||||
layer.input_names
|
||||
# pylint:enable=pointless-statement
|
||||
except AttributeError:
|
||||
# If the model does not have inputs set, because it was not called or its
|
||||
# input shapes were not recorded, we won't have a signature so can't trace
|
||||
# a function. But the user may still save an object with this Model
|
||||
# attached; we won't fail the whole tf.saved_model.save.
|
||||
logging.warning('Skipping full serialization of Keras model {}, because '
|
||||
'its inputs are not defined.'.format(layer))
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
if not layer.built:
|
||||
logging.warning('Skipping full serialization of Keras layer {}, because '
|
||||
'it is not built.'.format(layer))
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _wrap_layer_objects(layer, serialization_cache):
|
||||
"""Returns extra trackable objects to attach to the serialized layer.
|
||||
|
||||
Args:
|
||||
layer: Keras Layer object.
|
||||
serialization_cache: Dictionary shared between all objects during
|
||||
serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing all checkpointable objects from a
|
||||
SerializedAttributes object. See LayerAttributes and ModelAttributes for
|
||||
entire list of objects
|
||||
"""
|
||||
# Wrap all regularization losses as tf.functions.
|
||||
# First, generate list of all regularization losses in this layer and
|
||||
# sublayers.
|
||||
all_losses = layer._callable_losses[:] # pylint: disable=protected-access
|
||||
for child_layer in _list_all_layers(layer):
|
||||
all_losses.extend(child_layer._callable_losses) # pylint: disable=protected-access
|
||||
# Next, wrap all loss functions as tf.functions. Use the serialization cache
|
||||
# to store already-wrapped functions.
|
||||
keras_loss_cache = serialization_cache.setdefault('keras_losses', {})
|
||||
wrapped_loss_functions = []
|
||||
for loss_fn in all_losses:
|
||||
if loss_fn in keras_loss_cache:
|
||||
wrapped_loss_functions.append(keras_loss_cache[loss_fn])
|
||||
else:
|
||||
wrapped_loss = _wrap_unconditional_loss(loss_fn, len(keras_loss_cache))
|
||||
keras_loss_cache[loss_fn] = wrapped_loss
|
||||
wrapped_loss_functions.append(wrapped_loss)
|
||||
wrapped_layer_losses = [keras_loss_cache[fn]
|
||||
for fn in layer._callable_losses[:]] # pylint: disable=protected-access
|
||||
return dict(
|
||||
variables=data_structures.ListWrapper(layer.variables),
|
||||
trainable_variables=data_structures.ListWrapper(
|
||||
layer.trainable_variables),
|
||||
non_trainable_variables=data_structures.ListWrapper(
|
||||
layer.non_trainable_variables),
|
||||
layers=data_structures.ListWrapper(_list_all_layers(layer)),
|
||||
metrics=data_structures.ListWrapper(layer.metrics),
|
||||
regularization_losses=data_structures.ListWrapper(
|
||||
wrapped_loss_functions),
|
||||
layer_regularization_losses=data_structures.ListWrapper(
|
||||
wrapped_layer_losses))
|
||||
|
||||
|
||||
def _wrap_layer_functions(layer, serialization_cache):
|
||||
"""Returns dict of wrapped layer call function and losses in tf.functions.
|
||||
|
||||
Args:
|
||||
layer: Keras Layer object.
|
||||
serialization_cache: Dictionary shared between all objects during
|
||||
serialization.
|
||||
|
||||
Returns:
|
||||
A dictionary containing all keras tf.functions to serialize. See
|
||||
LayerAttributes and ModelAttributes for the list of all attributes.
|
||||
"""
|
||||
# Since Sequential models may be modified in place using model.add() or
|
||||
# model.pop(), don't use saved functions.
|
||||
if (isinstance(layer, keras_load.RevivedLayer) and
|
||||
not isinstance(layer, keras_load.RevivedSequential)):
|
||||
return {fn_name: getattr(layer.keras_api, fn_name, None)
|
||||
for fn_name in serialized_attributes.LayerAttributes.all_functions}
|
||||
|
||||
# Reset the losses of the layer and its children. The call function in each
|
||||
# child layer is replaced with tf.functions.
|
||||
original_fns = _replace_child_layer_functions(layer, serialization_cache)
|
||||
original_losses = _reset_layer_losses(layer)
|
||||
|
||||
# Wrap all the layer call and activity regularizer functions.
|
||||
|
||||
# Use LayerCallCollection to ensure that all layer call functions (__call__,
|
||||
# call with losses) are traced with the same inputs.
|
||||
call_collection = LayerCallCollection(layer)
|
||||
call_fn_with_losses = call_collection.add_function(
|
||||
_wrap_call_and_conditional_losses(layer),
|
||||
'{}_layer_call_and_return_conditional_losses'.format(layer.name))
|
||||
call_fn = call_collection.add_function(
|
||||
_extract_outputs_from_fn(layer, call_fn_with_losses),
|
||||
'{}_layer_call_fn'.format(layer.name))
|
||||
|
||||
fns = {'call_and_return_conditional_losses': call_fn_with_losses,
|
||||
'__call__': call_fn}
|
||||
|
||||
if layer.activity_regularizer is not None:
|
||||
fns['activity_regularizer_fn'] = _wrap_activity_regularizer(layer)
|
||||
fns['call_and_return_all_conditional_losses'] = (
|
||||
call_collection.add_function(
|
||||
_append_activity_regularizer_loss(call_fn_with_losses,
|
||||
fns['activity_regularizer_fn']),
|
||||
'{}_layer_call_and_return_all_conditional_losses'.format(layer.name)
|
||||
))
|
||||
else:
|
||||
fns['activity_regularizer_fn'] = None
|
||||
fns['call_and_return_all_conditional_losses'] = call_fn_with_losses
|
||||
|
||||
# Manually trigger traces before restoring the overwritten functions. The
|
||||
# functions are traced within the layer call context to ensure that layer
|
||||
# functions (e.g. add_loss) behave as though running in graph mode.
|
||||
with base_layer_utils.call_context().enter(layer, None, True, None):
|
||||
for fn in fns.values():
|
||||
if fn is not None and fn.input_signature is not None:
|
||||
fn.get_concrete_function()
|
||||
|
||||
# Restore overwritten functions and losses
|
||||
_restore_child_layer_functions(original_fns)
|
||||
_restore_layer_losses(original_losses)
|
||||
|
||||
return fns
|
||||
|
||||
|
||||
def _default_save_signature(layer):
|
||||
original_losses = _reset_layer_losses(layer)
|
||||
fn = saving_utils.trace_model_call(layer)
|
||||
fn.get_concrete_function()
|
||||
_restore_layer_losses(original_losses)
|
||||
return fn
|
||||
|
||||
|
||||
def _list_all_layers(obj):
|
||||
if isinstance(obj, training_lib.Model):
|
||||
return obj.layers
|
||||
else:
|
||||
return trackable_layer_utils.filter_empty_layer_containers(obj._layers) # pylint: disable=protected-access
|
||||
|
||||
|
||||
def _replace_child_layer_functions(layer, serialization_cache):
|
||||
"""Replaces functions in the children layers with wrapped tf.functions.
|
||||
|
||||
This step allows functions from parent layers to reference the wrapped
|
||||
functions from their children layers instead of retracing the ops.
|
||||
|
||||
This function also resets all losses stored in the layer. These are stored in
|
||||
the returned dictionary. Use `_restore_child_layer_functions` to restore
|
||||
the original attributes.
|
||||
|
||||
Args:
|
||||
layer: Keras Layer object.
|
||||
serialization_cache: Dictionary shared between all objects during
|
||||
serialization.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping layer objects -> original functions and losses:
|
||||
{ Child layer 1: {
|
||||
'losses': Original losses,
|
||||
'call': Original call function
|
||||
'activity_regularizer': Original activity regularizer},
|
||||
Child layer 2: ...
|
||||
}
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
original_fns = {}
|
||||
for child_layer in _list_all_layers(layer):
|
||||
if child_layer not in serialization_cache[_KERAS_CACHE_KEY]:
|
||||
layer_fns = (serialize_all_attributes(child_layer, serialization_cache)
|
||||
.functions)
|
||||
else:
|
||||
layer_fns = serialization_cache[_KERAS_CACHE_KEY][child_layer].functions
|
||||
if not layer_fns:
|
||||
# This indicates either:
|
||||
# - circular dependency, which means the current layer's functions
|
||||
# should be wrapped first.
|
||||
# - Child layer's inputs are not defined, so its functions have not been
|
||||
# wrapped. In this case, no replacement is necessary so move on to the
|
||||
# next child.
|
||||
continue
|
||||
original_fns[child_layer] = {
|
||||
'call': child_layer.call,
|
||||
'activity_regularizer': child_layer.activity_regularizer
|
||||
}
|
||||
with trackable.no_automatic_dependency_tracking_scope(child_layer):
|
||||
try:
|
||||
child_layer.activity_regularizer = layer_fns.get(
|
||||
'activity_regularizer_fn')
|
||||
except AttributeError:
|
||||
# Some layers have an unsettable activity regularizer.
|
||||
pass
|
||||
child_layer.call = utils.use_wrapped_call(
|
||||
child_layer, layer_fns['call_and_return_conditional_losses'])
|
||||
return original_fns
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
def _restore_child_layer_functions(original_fns):
|
||||
"""Restores attributes replaced with `_replace_child_layer_functions`."""
|
||||
for child_layer, fns in original_fns.items():
|
||||
with trackable.no_automatic_dependency_tracking_scope(child_layer):
|
||||
child_layer.call = fns['call']
|
||||
try:
|
||||
child_layer.activity_regularizer = fns['activity_regularizer']
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def _reset_layer_losses(parent_layer):
|
||||
"""Resets losses of layer and its sublayers, and returns original losses."""
|
||||
losses_dict = {}
|
||||
for layer in _list_all_layers(parent_layer) + [parent_layer]:
|
||||
losses_dict[layer] = {'losses': layer._losses[:],
|
||||
'eager_losses': layer._eager_losses[:]}
|
||||
with trackable.no_automatic_dependency_tracking_scope(layer):
|
||||
layer._losses = []
|
||||
layer._eager_losses = []
|
||||
return losses_dict
|
||||
|
||||
|
||||
def _restore_layer_losses(losses_dict):
|
||||
for layer in losses_dict:
|
||||
with trackable.no_automatic_dependency_tracking_scope(layer):
|
||||
layer._losses = losses_dict[layer]['losses']
|
||||
layer._eager_losses = losses_dict[layer]['eager_losses']
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
class LayerCallCollection(object):
|
||||
"""Groups wrapped layer call functions.
|
||||
|
||||
This is used to ensure that all layer call functions are traced with the same
|
||||
inputs:
|
||||
- call
|
||||
- call_and_return_conditional_losses
|
||||
- call_and_return_all_conditional_losses
|
||||
"""
|
||||
|
||||
def __init__(self, layer):
|
||||
self._layer = layer
|
||||
self._expects_training_arg = layer._expects_training_arg # pylint: disable=protected-access
|
||||
self._input_signature = self._generate_input_signature(layer)
|
||||
self._functions = weakref.WeakValueDictionary()
|
||||
# Bool indicating whether this object is currently tracing the layer call
|
||||
# functions.
|
||||
self.tracing = False
|
||||
|
||||
def _generate_input_signature(self, layer):
|
||||
"""Inspects layer object and returns the inferred input signature.
|
||||
|
||||
Args:
|
||||
layer: Layer object.
|
||||
|
||||
Returns:
|
||||
List of possibly nested TensorSpecs of the layer call function inputs.
|
||||
The list does not contain the `training` argument.
|
||||
"""
|
||||
if (isinstance(layer.call, def_function.Function) and
|
||||
layer.call.input_signature is not None):
|
||||
return layer.call.input_signature
|
||||
else:
|
||||
if isinstance(layer, training_lib.Model):
|
||||
return saving_utils.model_input_signature(layer)
|
||||
elif layer.input_spec is not None:
|
||||
|
||||
def to_tensor_spec_or_none(x):
|
||||
spec = input_spec.to_tensor_spec(x, layer.dtype)
|
||||
# If the shape is too general (e.g. multiple dimensions are allowed),
|
||||
# return None so that separate functions can be generated for each
|
||||
# inferred input signature.
|
||||
# TODO(b/134962016): currently partial signatures are not supported.
|
||||
if spec.shape == tensor_shape.TensorShape(None):
|
||||
return None
|
||||
return spec
|
||||
input_signature = [nest.map_structure(
|
||||
to_tensor_spec_or_none, layer.input_spec)]
|
||||
|
||||
return input_signature
|
||||
else:
|
||||
return None
|
||||
|
||||
def add_trace(self, *args, **kwargs):
|
||||
"""Traces all functions with the same args and kwargs.
|
||||
|
||||
Args:
|
||||
*args: Positional args passed to the original function.
|
||||
**kwargs: Keyword args passed to the original function.
|
||||
"""
|
||||
kwargs = kwargs.copy()
|
||||
self.tracing = True
|
||||
for fn in self._functions.values():
|
||||
# TODO(kathywu): Replace arguments with broader shapes defined in the
|
||||
# input signature.
|
||||
if self._expects_training_arg:
|
||||
kwargs['training'] = False
|
||||
fn.original_get_concrete_function(*args, **kwargs)
|
||||
kwargs['training'] = True
|
||||
fn.original_get_concrete_function(*args, **kwargs)
|
||||
else:
|
||||
fn.original_get_concrete_function(*args, **kwargs)
|
||||
self.tracing = False
|
||||
|
||||
@property
|
||||
def fn_input_signature(self):
|
||||
"""Returns input signature for the wrapped layer call function."""
|
||||
if self._expects_training_arg:
|
||||
# The training arg is left as a python boolean, so the call functions
|
||||
# will not have an input signature (input signatures may only describe
|
||||
# tensor arguments).
|
||||
return None
|
||||
if None in nest.flatten(self._input_signature):
|
||||
# TODO(b/134962016): If input signature cannot be partially defined.
|
||||
return None
|
||||
return self._input_signature
|
||||
|
||||
def add_function(self, python_function, name):
|
||||
"""Adds a layer call function to the collection."""
|
||||
self._functions[name] = fn = LayerCall(
|
||||
self, python_function, name,
|
||||
input_signature=self.fn_input_signature)
|
||||
|
||||
if (None not in nest.flatten(self._input_signature) and
|
||||
self._expects_training_arg):
|
||||
# Manually add traces for layers that expect a training argument and have
|
||||
# a fully defined input signature.
|
||||
self.add_trace(*self._input_signature)
|
||||
return fn
|
||||
|
||||
|
||||
class LayerCall(def_function.Function):
|
||||
"""Function that triggers traces of other functions in the same collection."""
|
||||
|
||||
def __init__(self, call_collection, *args, **kwargs):
|
||||
super(LayerCall, self).__init__(*args, **kwargs)
|
||||
self.call_collection = call_collection
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if not self.call_collection.tracing:
|
||||
self.call_collection.add_trace(*args, **kwargs)
|
||||
return super(LayerCall, self).__call__(*args, **kwargs)
|
||||
|
||||
def get_concrete_function(self, *args, **kwargs):
|
||||
if not self.call_collection.tracing:
|
||||
self.call_collection.add_trace(*args, **kwargs)
|
||||
return super(LayerCall, self).get_concrete_function(*args, **kwargs)
|
||||
|
||||
def original_get_concrete_function(self, *args, **kwargs):
|
||||
return super(LayerCall, self).get_concrete_function(*args, **kwargs)
|
||||
|
||||
|
||||
def _wrap_call_and_conditional_losses(layer):
|
||||
"""Wraps call function that returns a tuple of (outputs, losses).
|
||||
|
||||
The losses returned are conditional on the inputs passed to the call function.
|
||||
Unconditional losses (e.g. weight regularization) are wrapped separately.
|
||||
|
||||
Args:
|
||||
layer: a Keras layer object
|
||||
|
||||
Returns:
|
||||
python call function that returns outputs and conditional losses -- excludes
|
||||
activity regularizer
|
||||
"""
|
||||
# Create function that generates both outputs and losses
|
||||
layer_call = layer.call
|
||||
if layer._expects_training_arg: # pylint: disable=protected-access
|
||||
def call_and_return_conditional_losses(inputs, training=False):
|
||||
return layer_call(inputs, training=training), layer.get_losses_for(inputs)
|
||||
else:
|
||||
def call_and_return_conditional_losses(inputs):
|
||||
K.set_learning_phase(0)
|
||||
return layer_call(inputs), layer.get_losses_for(inputs)
|
||||
return call_and_return_conditional_losses
|
||||
|
||||
|
||||
def _extract_outputs_from_fn(layer, call_and_return_conditional_losses):
|
||||
"""Returns a function that returns only call function outputs."""
|
||||
if isinstance(layer, keras_load.RevivedLayer):
|
||||
return layer.keras_api.__call__ # pylint: disable=protected-access
|
||||
if layer._expects_training_arg: # pylint: disable=protected-access
|
||||
def call(inputs, training=False):
|
||||
return call_and_return_conditional_losses(inputs, training=training)[0]
|
||||
else:
|
||||
def call(inputs):
|
||||
return call_and_return_conditional_losses(inputs)[0]
|
||||
return call
|
||||
|
||||
|
||||
def _append_activity_regularizer_loss(
|
||||
call_fn_with_losses, activity_regularizer_fn):
|
||||
"""Appends activity regularizer loss to losses returned by the wrapped fn."""
|
||||
def fn(*args, **kwargs):
|
||||
outputs, losses = call_fn_with_losses(*args, **kwargs)
|
||||
losses.append(activity_regularizer_fn(outputs))
|
||||
return outputs, losses
|
||||
return fn
|
||||
|
||||
|
||||
def _wrap_unconditional_loss(loss_fn, index):
|
||||
"""Wraps callable/unconditonal loss, returning a serializable function."""
|
||||
# Extract original loss function from partial function
|
||||
fn = loss_fn.args[0] if isinstance(loss_fn, functools.partial) else loss_fn
|
||||
if isinstance(fn, def_function.Function):
|
||||
return fn
|
||||
else:
|
||||
return def_function.Function(
|
||||
fn, 'loss_fn_{}'.format(index), input_signature=[])
|
||||
|
||||
|
||||
def _wrap_activity_regularizer(layer):
|
||||
"""Wraps the activity regularizer."""
|
||||
if isinstance(layer.activity_regularizer, def_function.Function):
|
||||
return layer.activity_regularizer
|
||||
return def_function.Function(
|
||||
layer.activity_regularizer,
|
||||
'{}_activity_regularizer'.format(layer.name),
|
||||
input_signature=[tensor_spec.TensorSpec(None, layer.dtype or K.floatx())])
|
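To make the wrapping pattern above concrete, here is a stripped-down sketch (editor's addition, not the actual implementation) of the shape of function produced by `_wrap_call_and_conditional_losses`: the layer's call is paired with only the losses that depend on the traced inputs. It uses the same `add_loss(..., inputs=...)` / `get_losses_for` APIs as the test layer in this commit; those APIs have changed in later TF releases.

```python
import tensorflow as tf

class LayerWithLoss(tf.keras.layers.Layer):
  """Toy layer that registers a loss conditional on its inputs."""

  def call(self, inputs):
    self.add_loss(tf.reduce_sum(inputs) * 0.01, inputs=inputs)
    return inputs * 2.0

layer = LayerWithLoss()

def call_and_return_conditional_losses(inputs):
  # Same shape as the wrapped function: run the layer and return its outputs
  # together with only the losses that are conditional on `inputs`.
  return layer(inputs), layer.get_losses_for(inputs)

fn = tf.function(call_and_return_conditional_losses,
                 input_signature=[tf.TensorSpec([None, 3], tf.float32)])
outputs, conditional_losses = fn(tf.ones([2, 3]))
```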
tensorflow/python/keras/saving/saved_model/saved_model_test.py (new file, 311 lines)

@@ -0,0 +1,311 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# pylint: disable=protected-access
|
||||
"""Tests for saving/loading function for keras Model."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import regularizers
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.saving.saved_model import load as saved_model_load
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import load as tf_load
|
||||
from tensorflow.python.saved_model import save as tf_save
|
||||
|
||||
|
||||
class LayerWithLearningPhase(keras.engine.base_layer.Layer):
|
||||
|
||||
def build(self, input_shape):
|
||||
self.input_spec = keras.layers.InputSpec(shape=[None] * len(input_shape))
|
||||
self.built = True
|
||||
|
||||
def call(self, x, training=None):
|
||||
if training is None:
|
||||
training = keras.backend.learning_phase()
|
||||
output = tf_utils.smart_cond(
|
||||
training, lambda: x * 0, lambda: array_ops.identity(x))
|
||||
if not context.executing_eagerly():
|
||||
output._uses_learning_phase = True # pylint: disable=protected-access
|
||||
return output
|
||||
|
||||
def compute_output_shape(self, input_shape):
|
||||
return input_shape
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
|
||||
def _save_model_dir(self, dirname='saved_model'):
|
||||
temp_dir = self.get_temp_dir()
|
||||
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
|
||||
return os.path.join(temp_dir, dirname)
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_model_save_and_load(self):
|
||||
input_arr = np.random.random((1, 3)).astype(np.float32)
|
||||
target_arr = np.random.random((1, 4)).astype(np.float32)
|
||||
|
||||
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
|
||||
model.layers[-1].activity_regularizer = regularizers.get('l2')
|
||||
model.activity_regularizer = regularizers.get('l2')
|
||||
model.compile(
|
||||
loss='mse',
|
||||
optimizer='rmsprop')
|
||||
model.train_on_batch(input_arr, target_arr)
|
||||
|
||||
def callable_loss():
|
||||
return math_ops.reduce_sum(model.weights[0])
|
||||
model.add_loss(callable_loss)
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
|
||||
loaded = saved_model_load.load(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
self.assertAllClose(self.evaluate(model.weights),
|
||||
self.evaluate(loaded.weights))
|
||||
|
||||
input_arr = constant_op.constant(
|
||||
np.random.random((1, 3)).astype(np.float32))
|
||||
self.assertAllClose(self.evaluate(model(input_arr)),
|
||||
self.evaluate(loaded(input_arr)))
|
||||
# Validate losses. The order of conditional losses may change between the
|
||||
# model and loaded model, so sort the losses first.
|
||||
if context.executing_eagerly():
|
||||
self.assertAllClose(sorted(self.evaluate(model.losses)),
|
||||
sorted(self.evaluate(loaded.losses)))
|
||||
else:
|
||||
self.assertAllClose(self.evaluate(model.get_losses_for(None)),
|
||||
self.evaluate(loaded.get_losses_for(None)))
|
||||
self.assertAllClose(
|
||||
sorted(self.evaluate(model.get_losses_for(input_arr))),
|
||||
sorted(self.evaluate(loaded.get_losses_for(input_arr))))
|
||||
|
||||
def test_trainable_weights(self):
|
||||
layer = keras.layers.Dense(4, name='custom_layer')
|
||||
layer.build([3,])
|
||||
layer.add_weight(
|
||||
'extra_weight', shape=[],
|
||||
initializer=init_ops.constant_initializer(11),
|
||||
trainable=True)
|
||||
layer.add_weight(
|
||||
'extra_weight_2', shape=[],
|
||||
initializer=init_ops.constant_initializer(12),
|
||||
trainable=False)
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
self.evaluate(variables.variables_initializer(layer.variables))
|
||||
tf_save.save(layer, saved_model_dir)
|
||||
loaded = saved_model_load.load(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
|
||||
equal_attrs = ['name', '_expects_training_arg', 'trainable']
|
||||
for attr in equal_attrs:
|
||||
self.assertEqual(getattr(layer, attr), getattr(loaded, attr))
|
||||
|
||||
all_close = ['weights', 'trainable_weights', 'non_trainable_weights']
|
||||
for attr in all_close:
|
||||
self.assertAllClose(self.evaluate(getattr(layer, attr)),
|
||||
self.evaluate(getattr(loaded, attr)))
|
||||
|
||||
def test_maintains_losses(self):
|
||||
"""Tests that the layer losses do not change before and after export."""
|
||||
|
||||
class LayerWithLoss(keras.layers.Layer):
|
||||
|
||||
def call(self, inputs):
|
||||
self.add_loss(math_ops.reduce_sum(inputs), inputs)
|
||||
return inputs
|
||||
|
||||
model = keras.models.Sequential([LayerWithLoss()])
|
||||
model.compile(
|
||||
loss='mse',
|
||||
optimizer='rmsprop')
|
||||
input_arr = np.random.random((1, 3)).astype(np.float32)
|
||||
target_arr = np.random.random((1, 3)).astype(np.float32)
|
||||
|
||||
# Test that symbolic losses are maintained (train_on_batch saves symbolic
|
||||
# losses.)
|
||||
model.train_on_batch(input_arr, target_arr)
|
||||
previous_losses = model.losses[:]
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
self.assertAllEqual(previous_losses, model.losses)
|
||||
|
||||
if context.executing_eagerly():
|
||||
# Test that eager losses are maintained.
|
||||
model(input_arr) # Calls model eagerly, creating eager losses.
|
||||
previous_losses = model.losses[:]
|
||||
tf_save.save(model, saved_model_dir)
|
||||
self.assertAllEqual(previous_losses, model.losses)
|
||||
|
||||
def test_layer_with_learning_phase(self):
|
||||
layer = LayerWithLearningPhase()
|
||||
layer.build([None, None])
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(layer, saved_model_dir)
|
||||
loaded = saved_model_load.load(saved_model_dir)
|
||||
input_arr = array_ops.ones((4, 3))
|
||||
|
||||
# Run the layer, and use the keras backend learning phase
|
||||
keras.backend.set_learning_phase(0)
|
||||
self.assertAllEqual(input_arr, loaded(input_arr))
|
||||
keras.backend.set_learning_phase(1)
|
||||
self.assertAllEqual(array_ops.zeros((4, 3)), loaded(input_arr))
|
||||
|
||||
# Run the layer while explicitly setting the training argument
|
||||
self.assertAllEqual(
|
||||
input_arr, loaded(input_arr, training=constant_op.constant(False)))
|
||||
self.assertAllEqual(
|
||||
array_ops.zeros((4, 3)),
|
||||
loaded(input_arr, training=constant_op.constant(True)))
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_standard_loader(self):
|
||||
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
|
||||
model.activity_regularizer = regularizers.get('l2')
|
||||
def eager_loss():
|
||||
return math_ops.reduce_sum(model.weights[0])
|
||||
model.add_loss(eager_loss)
|
||||
|
||||
# Call predict to ensure that all layers are built and inputs are set.
|
||||
model.predict(np.random.random((1, 3)))
|
||||
saved_model_dir = self._save_model_dir()
|
||||
|
||||
tf_save.save(model, saved_model_dir)
|
||||
|
||||
loaded = tf_load.load(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
all_close = ['variables', 'trainable_variables',
|
||||
'non_trainable_variables']
|
||||
for attr in all_close:
|
||||
self.assertAllClose(self.evaluate(getattr(model, attr)),
|
||||
self.evaluate(getattr(loaded.keras_api, attr)))
|
||||
self.assertLen(loaded.regularization_losses, 1)
|
||||
expected_layers = len(model.layers)
|
||||
self.assertEqual(expected_layers, len(loaded.keras_api.layers))
|
||||
input_arr = array_ops.ones((4, 3))
|
||||
self.assertAllClose(self.evaluate(model(input_arr)),
|
||||
self.evaluate(loaded(input_arr)))
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_compiled_model(self):
|
||||
input_arr = np.random.random((1, 3))
|
||||
target_arr = np.random.random((1, 4))
|
||||
|
||||
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
|
||||
expected_predict = model.predict(input_arr)
|
||||
|
||||
# Compile and save model.
|
||||
model.compile('rmsprop', 'mse')
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
|
||||
# TODO(b/134519980): Issue with model.fit if the model call function uses
|
||||
# a tf.function (Graph mode only).
|
||||
with context.eager_mode():
|
||||
loaded = saved_model_load.load(saved_model_dir)
|
||||
actual_predict = loaded.predict(input_arr)
|
||||
self.assertAllClose(expected_predict, actual_predict)
|
||||
|
||||
loss_before = loaded.evaluate(input_arr, target_arr)
|
||||
loaded.fit(input_arr, target_arr)
|
||||
loss_after = loaded.evaluate(input_arr, target_arr)
|
||||
self.assertLess(loss_after, loss_before)
|
||||
predict = loaded.predict(input_arr)
|
||||
|
||||
ckpt_path = os.path.join(self.get_temp_dir(), 'weights')
|
||||
loaded.save_weights(ckpt_path)
|
||||
|
||||
# Ensure that the checkpoint is compatible with the original model.
|
||||
model.load_weights(ckpt_path)
|
||||
self.assertAllClose(predict, model.predict(input_arr))
|
||||
|
||||
def test_metadata_input_spec(self):
|
||||
class LayerWithNestedSpec(keras.layers.Layer):
|
||||
|
||||
def __init__(self):
|
||||
super(LayerWithNestedSpec, self).__init__()
|
||||
self.input_spec = {
|
||||
'a': keras.layers.InputSpec(max_ndim=3, axes={-1: 2}),
|
||||
'b': keras.layers.InputSpec(shape=(None, 2, 3), dtype='float16')}
|
||||
|
||||
layer = LayerWithNestedSpec()
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(layer, saved_model_dir)
|
||||
loaded = saved_model_load.load(saved_model_dir)
|
||||
self.assertEqual(3, loaded.input_spec['a'].max_ndim)
|
||||
self.assertEqual({-1: 2}, loaded.input_spec['a'].axes)
|
||||
self.assertAllEqual([None, 2, 3], loaded.input_spec['b'].shape)
|
||||
self.assertEqual('float16', loaded.input_spec['b'].dtype)
|
||||
|
||||
def test_multi_input_model(self):
|
||||
input_1 = keras.layers.Input(shape=(3,))
|
||||
input_2 = keras.layers.Input(shape=(5,))
|
||||
model = keras.Model([input_1, input_2], [input_1, input_2])
|
||||
saved_model_dir = self._save_model_dir()
|
||||
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded = saved_model_load.load(saved_model_dir)
|
||||
input_arr_1 = np.random.random((1, 3)).astype('float32')
|
||||
input_arr_2 = np.random.random((1, 5)).astype('float32')
|
||||
|
||||
outputs = loaded([input_arr_1, input_arr_2])
|
||||
self.assertAllEqual(input_arr_1, outputs[0])
|
||||
self.assertAllEqual(input_arr_2, outputs[1])
|
||||
|
||||
def test_revived_sequential(self):
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(5, input_shape=(3,),
|
||||
kernel_regularizer=regularizers.get('l2')))
|
||||
model.add(keras.layers.Dense(2, kernel_regularizer=regularizers.get('l2')))
|
||||
|
||||
self.evaluate(variables.variables_initializer(model.variables))
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded = saved_model_load.load(saved_model_dir)
|
||||
|
||||
self.assertLen(loaded.layers, 2)
|
||||
self.assertLen(loaded.losses, 2)
|
||||
|
||||
loaded.pop()
|
||||
|
||||
self.assertLen(loaded.layers, 1)
|
||||
self.assertLen(loaded.losses, 1)
|
||||
|
||||
loaded.add(keras.layers.Dense(2, kernel_regularizer=regularizers.get('l2')))
|
||||
|
||||
self.assertLen(loaded.layers, 2)
|
||||
self.assertLen(loaded.losses, 2)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
267
tensorflow/python/keras/saving/saved_model/serialized_attributes.py
Normal file
@ -0,0 +1,267 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Helper classes that list&validate all attributes to serialize to SavedModel.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function as defun
|
||||
from tensorflow.python.keras.saving.saved_model import constants
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.training.tracking.tracking import AutoTrackable
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
|
||||
# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
|
||||
# once the issue with copybara is fixed.
|
||||
# pylint:disable=g-inconsistent-quotes
|
||||
base_layer = LazyLoader(
|
||||
"base_layer", globals(),
|
||||
"tensorflow.python.keras.engine.base_layer")
|
||||
training_lib = LazyLoader(
|
||||
"training_lib", globals(),
|
||||
"tensorflow.python.keras.engine.training")
|
||||
# pylint:enable=g-inconsistent-quotes
|
||||
|
||||
|
||||
class SerializedAttributes(object):
|
||||
"""Class that tracks and validates all serialization attributes.
|
||||
|
||||
Keras models contain many Python-defined components. For example, the
|
||||
trainable_variable property lists the model's trainable variables by
|
||||
recursively retrieving the trainable variables from each of the child layers.
|
||||
Another example is model.call, a python function that calls child layers and
|
||||
adds ops to the backend graph.
|
||||
|
||||
Only Tensorflow checkpointable objects and functions can be serialized to
|
||||
SavedModel. Serializing a Keras model as-is results in a checkpointable object
|
||||
that does not resemble a Keras model at all. Thus, extra checkpointable
|
||||
objects and functions must be created during serialization.
|
||||
|
||||
**Defining new serialized attributes**
|
||||
Child classes should be defined using:
|
||||
SerializedAttributes.with_attributes(
|
||||
'name', checkpointable_objects=[...], functions=[...], copy_from=[...])
|
||||
This class is used to cache generated checkpointable objects and functions,
|
||||
ensuring that new objects and functions are generated a single time.
|
||||
|
||||
**Usage during serialization**
|
||||
Each Layer/Model object should have a corresponding instance of
|
||||
SerializedAttributes. Create a new instance by calling
|
||||
`SerializedAttributes.new(obj)`. Objects and functions may be saved using
|
||||
`.set_and_validate_objects`/`.set_and_validate_functions`.
|
||||
The properties `.checkpointable_objects` and `.functions` return the cached
|
||||
values.
|
||||
|
||||
**Adding/changing attributes to save to SavedModel**
|
||||
1. Change the call to `SerializedAttributes.with_attributes` in the correct
|
||||
class:
|
||||
- CommonEndpoints: Base attributes to be added during serialization. If
|
||||
these attributes are present in a Trackable object, it can be
|
||||
deserialized to a Keras Model.
|
||||
- LayerAttributes: Attributes to serialize for Layer objects.
|
||||
- ModelAttributes: Attributes to serialize for Model objects.
|
||||
2. Update class docstring
|
||||
3. Update arguments to any calls to `set_and_validate_*`. For example, if
|
||||
`call_raw_tensors` is added to the ModelAttributes function list, then
|
||||
a `call_raw_tensors` function should be passed to
|
||||
`set_and_validate_functions`.
|
||||
|
||||
**Common endpoints vs other attributes**
|
||||
Only common endpoints are attached directly to the root object. Keras-specific
|
||||
attributes are saved to a separate trackable object with the name "keras_api".
|
||||
The number of objects attached to the root is limited because any naming
|
||||
conflicts will cause user code to break.
|
||||
|
||||
Another reason is that this will only affect users who call
|
||||
`tf.saved_model.load` instead of `tf.keras.models.load_model`. These are
|
||||
advanced users who are likely to have defined their own tf.functions and
|
||||
trackable objects. The added Keras-specific attributes are kept out of the way
|
||||
in the "keras_api" namespace.
|
||||
|
||||
Properties defined in this class may be used to filter out keras-specific
|
||||
attributes:
|
||||
- `functions_to_serialize`: Returns dict of functions to attach to the root
|
||||
object.
|
||||
- `checkpointable_objects_to_serialize`: Returns dict of objects to attach to
|
||||
the root object (including separate trackable object containing
|
||||
keras-specific attributes)
|
||||
|
||||
All changes to the serialized attributes must be backwards-compatible, so
|
||||
attributes should not be removed or modified without sufficient justification.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def with_attributes(
|
||||
name, checkpointable_objects=None, functions=None, copy_from=None):
|
||||
"""Creates a subclass with all attributes as specified in the arguments.
|
||||
|
||||
Args:
|
||||
name: Name of subclass
|
||||
checkpointable_objects: List of checkpointable objects to be serialized
|
||||
in the SavedModel.
|
||||
functions: List of functions to be serialized in the SavedModel.
|
||||
copy_from: List of other SerializedAttributes subclasses. The returned
|
||||
class will copy checkpoint objects/functions from each subclass.
|
||||
|
||||
Returns:
|
||||
Child class with attributes as defined in the `checkpointable_objects`
|
||||
and `functions` lists.
|
||||
"""
|
||||
checkpointable_objects = checkpointable_objects or []
|
||||
functions = functions or []
|
||||
|
||||
if copy_from is not None:
|
||||
for cls in copy_from:
|
||||
checkpointable_objects.extend(cls.all_checkpointable_objects)
|
||||
functions.extend(cls.all_functions)
|
||||
|
||||
classdict = {
|
||||
'all_checkpointable_objects': set(checkpointable_objects),
|
||||
'all_functions': set(functions)}
|
||||
return type(name, (SerializedAttributes,), classdict)
|
||||
|
||||
@staticmethod
|
||||
def new(obj):
|
||||
if isinstance(obj, training_lib.Model):
|
||||
return ModelAttributes()
|
||||
elif isinstance(obj, base_layer.Layer):
|
||||
return LayerAttributes()
|
||||
else:
|
||||
raise TypeError('Internal error during serialization: Expected Keras '
|
||||
'Layer object, got {} of type {}'.format(obj, type(obj)))
|
||||
|
||||
def __init__(self):
|
||||
self._object_dict = {}
|
||||
self._function_dict = {}
|
||||
self._keras_trackable = AutoTrackable()
|
||||
|
||||
@property
|
||||
def functions(self):
|
||||
"""Returns dictionary of all functions."""
|
||||
return {key: value for key, value in self._function_dict.items()
|
||||
if value is not None}
|
||||
|
||||
@property
|
||||
def checkpointable_objects(self):
|
||||
"""Returns dictionary of all checkpointable objects."""
|
||||
return {key: value for key, value in self._object_dict.items()
|
||||
if value is not None}
|
||||
|
||||
@property
|
||||
def functions_to_serialize(self):
|
||||
"""Returns functions to attach to the root object during serialization."""
|
||||
return {key: value for key, value in self.functions.items()
|
||||
if key in CommonEndpoints.all_functions}
|
||||
|
||||
@property
|
||||
def objects_to_serialize(self):
|
||||
"""Returns objects to attach to the root object during serialization."""
|
||||
objects = {key: value for key, value in self.checkpointable_objects.items()
|
||||
if key in CommonEndpoints.all_checkpointable_objects}
|
||||
objects[constants.KERAS_ATTR] = self._keras_trackable
|
||||
return objects
|
||||
|
||||
def set_and_validate_functions(self, function_dict):
|
||||
"""Saves function dictionary, and validates dictionary values."""
|
||||
for key in self.all_functions:
|
||||
if key in function_dict:
|
||||
if (function_dict[key] is not None and # Not all functions are required
|
||||
not isinstance(function_dict[key],
|
||||
(defun.Function, def_function.Function))):
|
||||
raise ValueError(
|
||||
'Function dictionary contained a non-function object: {} (for key'
|
||||
' {})'.format(function_dict[key], key))
|
||||
self._function_dict[key] = function_dict[key]
|
||||
setattr(self._keras_trackable, key, function_dict[key])
|
||||
else:
|
||||
raise ValueError('Function {} missing from serialized function dict.'
|
||||
.format(key))
|
||||
return self.functions
|
||||
|
||||
def set_and_validate_objects(self, object_dict):
|
||||
"""Saves objects to a dictionary, and validates the values."""
|
||||
for key in self.all_checkpointable_objects:
|
||||
if key in object_dict:
|
||||
if not isinstance(object_dict[key], trackable.Trackable):
|
||||
raise ValueError(
|
||||
'Object dictionary contained a non-trackable object: {} (for key'
|
||||
' {})'.format(object_dict[key], key))
|
||||
self._object_dict[key] = object_dict[key]
|
||||
setattr(self._keras_trackable, key, object_dict[key])
|
||||
else:
|
||||
raise ValueError('Object {} missing from serialized object dict.'.format(key))
|
||||
return self.checkpointable_objects
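# Sketch of how these setters are typically driven during export (illustrative;
# `traced_fns` and `tracked_objs` are assumed to be dicts whose keys cover
# `all_functions` and `all_checkpointable_objects` respectively):
#
#   attrs = SerializedAttributes.new(layer)        # LayerAttributes instance
#   attrs.set_and_validate_functions(traced_fns)
#   attrs.set_and_validate_objects(tracked_objs)
#   # CommonEndpoints entries are exposed through `functions_to_serialize` and
#   # `objects_to_serialize`; everything else lands on the "keras_api" object.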
|
||||
|
||||
|
||||
class CommonEndpoints(SerializedAttributes.with_attributes(
|
||||
'CommonEndpoints',
|
||||
checkpointable_objects=['variables', 'trainable_variables',
|
||||
'regularization_losses'],
|
||||
functions=['__call__', 'call_and_return_all_conditional_losses',
|
||||
'_default_save_signature'])):
|
||||
"""Common endpoints shared by all models loadable by Keras.
|
||||
|
||||
List of all attributes:
|
||||
variables: List of all variables in the model and its sublayers.
|
||||
trainable_variables: List of all trainable variables in the model and its
|
||||
sublayers.
|
||||
regularization_losses: List of all unconditional losses (losses not dependent
|
||||
on the inputs) in the model and its sublayers.
|
||||
__call__: Function that takes inputs and returns the outputs of the model
|
||||
call function.
|
||||
call_and_return_all_conditional_losses: Function that returns a tuple of
|
||||
(call function outputs, list of all losses that depend on the inputs).
|
||||
_default_save_signature: Traced model call function. This is only included
|
||||
if the top level exported object is a Keras model.
|
||||
"""
|
||||
|
||||
|
||||
class LayerAttributes(SerializedAttributes.with_attributes(
|
||||
'LayerAttributes',
|
||||
checkpointable_objects=['non_trainable_variables', 'layers', 'metrics',
|
||||
'layer_regularization_losses'],
|
||||
functions=['call_and_return_conditional_losses', 'activity_regularizer_fn'],
|
||||
copy_from=[CommonEndpoints]
|
||||
)):
|
||||
"""Layer checkpointable objects + functions that are saved to the SavedModel.
|
||||
|
||||
List of all attributes:
|
||||
All attributes from CommonEndpoints
|
||||
non_trainable_variables: List of non-trainable variables in the layer and
|
||||
its sublayers.
|
||||
layers: List of all sublayers.
|
||||
metrics: List of all metrics in the layer and its sublayers.
|
||||
call_and_return_conditional_losses: Function that takes inputs and returns a
|
||||
tuple of (outputs of the call function, list of input-dependent losses).
|
||||
The list of losses excludes the activity regularizer function, which is
|
||||
separate to allow the deserialized Layer object to define a different
|
||||
activity regularizer.
|
||||
activity_regularizer_fn: Callable that returns the activity regularizer loss
|
||||
layer_regularization_losses: List of losses owned only by this layer.
|
||||
"""
|
||||
|
||||
|
||||
class ModelAttributes(SerializedAttributes.with_attributes(
|
||||
'ModelAttributes',
|
||||
copy_from=[LayerAttributes])):
|
||||
"""Model checkpointable objects + functions that are saved to the SavedModel.
|
||||
|
||||
List of all attributes:
|
||||
All attributes from LayerAttributes (including CommonEndpoints)
|
||||
"""
|
||||
# TODO(kathywu): Add attributes `compile_losses` and `compile_metrics`, which
|
||||
# list all losses and metrics defined by `model.compile`.
|
51
tensorflow/python/keras/saving/saved_model/utils.py
Normal file
@ -0,0 +1,51 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Utility functions shared between SavedModel saving/loading implementations."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
|
||||
|
||||
def use_wrapped_call(layer, call_fn):
|
||||
"""Creates fn that adds the losses returned by call_fn & returns the outputs.
|
||||
|
||||
Args:
|
||||
layer: A Keras layer object
|
||||
call_fn: tf.function that takes layer inputs (and possibly a training arg),
|
||||
and returns a tuple of (outputs, list of losses).
|
||||
|
||||
Returns:
|
||||
function that calls call_fn and returns the outputs. Losses returned by
|
||||
call_fn are added to the layer losses.
|
||||
"""
|
||||
# TODO(kathywu): Support mask argument and multi-input call functions.
|
||||
def wrapped_call(inputs, **kwargs):
|
||||
"""Returns the outputs from the call_fn, and adds the losses."""
|
||||
if layer._expects_training_arg: # pylint: disable=protected-access
|
||||
training = kwargs.pop('training', None)
|
||||
if training is None:
|
||||
training = K.learning_phase()
|
||||
outputs, losses = tf_utils.smart_cond(
|
||||
training,
|
||||
lambda: call_fn(inputs, training=True),
|
||||
lambda: call_fn(inputs, training=False))
|
||||
else:
|
||||
outputs, losses = call_fn(inputs)
|
||||
layer.add_loss(losses, inputs)
|
||||
return outputs
|
||||
return wrapped_call
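# Example usage (a sketch; `layer` and `call_and_return_losses` are assumed to
# exist and to follow the contract described in the docstring above):
#
#   wrapped = use_wrapped_call(layer, call_and_return_losses)
#   outputs = wrapped(inputs, training=True)
#   # Losses returned by `call_and_return_losses` are now in `layer.losses`.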
|
428
tensorflow/python/keras/saving/saved_model_experimental.py
Normal file
@ -0,0 +1,428 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Deprecated experimental Keras SavedModel implementation."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import six
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras import optimizers
|
||||
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
|
||||
from tensorflow.python.keras.saving import model_from_json
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.utils import mode_keys
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import builder as saved_model_builder
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import model_utils
|
||||
from tensorflow.python.saved_model import save as save_lib
|
||||
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training.tracking import graph_view
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
# To avoid circular dependencies between keras/engine and keras/saving,
|
||||
# code in keras/saving must delay imports.
|
||||
|
||||
# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
|
||||
# once the issue with copybara is fixed.
|
||||
# pylint:disable=g-inconsistent-quotes
|
||||
metrics_lib = LazyLoader("metrics_lib", globals(),
|
||||
"tensorflow.python.keras.metrics")
|
||||
models_lib = LazyLoader("models_lib", globals(),
|
||||
"tensorflow.python.keras.models")
|
||||
sequential = LazyLoader(
|
||||
"sequential", globals(),
|
||||
"tensorflow.python.keras.engine.sequential")
|
||||
# pylint:enable=g-inconsistent-quotes
|
||||
|
||||
|
||||
@deprecation.deprecated(
|
||||
date=None,
|
||||
instructions=('Please use `model.save(..., save_format="tf")` or '
|
||||
'`tf.keras.models.save_model(..., save_format="tf")`.'))
|
||||
@keras_export('keras.experimental.export_saved_model')
|
||||
def export_saved_model(model,
|
||||
saved_model_path,
|
||||
custom_objects=None,
|
||||
as_text=False,
|
||||
input_signature=None,
|
||||
serving_only=False):
|
||||
"""Exports a `tf.keras.Model` as a Tensorflow SavedModel.
|
||||
|
||||
Note that at this time, subclassed models can only be saved using
|
||||
`serving_only=True`.
|
||||
|
||||
The exported `SavedModel` is a standalone serialization of Tensorflow objects,
|
||||
and is supported by TF language APIs and the Tensorflow Serving system.
|
||||
To load the model, use the function
|
||||
`tf.keras.experimental.load_from_saved_model`.
|
||||
|
||||
The `SavedModel` contains:
|
||||
|
||||
1. a checkpoint containing the model weights.
|
||||
2. a `SavedModel` proto containing the Tensorflow backend graph. Separate
|
||||
graphs are saved for prediction (serving), train, and evaluation. If
|
||||
the model has not been compiled, then only the graph computing predictions
|
||||
will be exported.
|
||||
3. the model's json config. If the model is subclassed, this will only be
|
||||
included if the model's `get_config()` method is overridden.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
|
||||
# Create a tf.keras model.
|
||||
model = tf.keras.Sequential()
|
||||
model.add(tf.keras.layers.Dense(1, input_shape=[10]))
|
||||
model.summary()
|
||||
|
||||
# Save the tf.keras model in the SavedModel format.
|
||||
path = '/tmp/simple_keras_model'
|
||||
tf.keras.experimental.export_saved_model(model, path)
|
||||
|
||||
# Load the saved keras model back.
|
||||
new_model = tf.keras.experimental.load_from_saved_model(path)
|
||||
new_model.summary()
|
||||
```
|
||||
|
||||
Args:
|
||||
model: A `tf.keras.Model` to be saved. If the model is subclassed, the flag
|
||||
`serving_only` must be set to True.
|
||||
saved_model_path: a string specifying the path to the SavedModel directory.
|
||||
custom_objects: Optional dictionary mapping string names to custom classes
|
||||
or functions (e.g. custom loss functions).
|
||||
as_text: bool, `False` by default. Whether to write the `SavedModel` proto
|
||||
in text format. Currently unavailable in serving-only mode.
|
||||
input_signature: A possibly nested sequence of `tf.TensorSpec` objects, used
|
||||
to specify the expected model inputs. See `tf.function` for more details.
|
||||
serving_only: bool, `False` by default. When this is true, only the
|
||||
prediction graph is saved.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the model is a subclassed model, and serving_only is
|
||||
False.
|
||||
ValueError: If the input signature cannot be inferred from the model.
|
||||
AssertionError: If the SavedModel directory already exists and isn't empty.
|
||||
"""
|
||||
if serving_only:
|
||||
save_lib.save(
|
||||
model,
|
||||
saved_model_path,
|
||||
signatures=saving_utils.trace_model_call(model, input_signature))
|
||||
else:
|
||||
_save_v1_format(model, saved_model_path, custom_objects, as_text,
|
||||
input_signature)
|
||||
|
||||
try:
|
||||
_export_model_json(model, saved_model_path)
|
||||
except NotImplementedError:
|
||||
logging.warning('Skipped saving model JSON, subclassed model does not have '
|
||||
'get_config() defined.')
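# For reference, the exported directory typically has the layout sketched below
# (exact contents depend on `serving_only` and on whether the model was
# compiled):
#
#   saved_model_path/
#     saved_model.pb           # serving (and, if compiled, train/eval) graphs
#     variables/               # object-based checkpoint with the model weights
#     assets/saved_model.json  # model config, when `get_config()` is defined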
|
||||
|
||||
|
||||
def _export_model_json(model, saved_model_path):
|
||||
"""Saves model configuration as a json string under assets folder."""
|
||||
model_json = model.to_json()
|
||||
model_json_filepath = os.path.join(
|
||||
saved_model_utils.get_or_create_assets_dir(saved_model_path),
|
||||
compat.as_text(constants.SAVED_MODEL_FILENAME_JSON))
|
||||
file_io.write_string_to_file(model_json_filepath, model_json)
|
||||
|
||||
|
||||
def _export_model_variables(model, saved_model_path):
|
||||
"""Saves model weights in checkpoint format under variables folder."""
|
||||
saved_model_utils.get_or_create_variables_dir(saved_model_path)
|
||||
checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path)
|
||||
model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
|
||||
return checkpoint_prefix
|
||||
|
||||
|
||||
def _save_v1_format(model, path, custom_objects, as_text, input_signature):
|
||||
"""Exports model to v1 SavedModel format."""
|
||||
if not model._is_graph_network: # pylint: disable=protected-access
|
||||
if isinstance(model, sequential.Sequential):
|
||||
# If input shape is not directly set in the model, the exported model
|
||||
# will infer the expected shapes of the input from the model.
|
||||
if not model.built:
|
||||
raise ValueError('Weights for sequential model have not yet been '
|
||||
'created. Weights are created when the Model is first '
|
||||
'called on inputs or `build()` is called with an '
|
||||
'`input_shape`, or the first layer in the model has '
|
||||
'`input_shape` during construction.')
|
||||
# TODO(kathywu): Build the model with input_signature to create the
|
||||
# weights before _export_model_variables().
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
'Subclassed models can only be exported for serving. Please set '
|
||||
'argument serving_only=True.')
|
||||
|
||||
builder = saved_model_builder._SavedModelBuilder(path) # pylint: disable=protected-access
|
||||
|
||||
# Manually save variables to export them in an object-based checkpoint. This
|
||||
# skips the `builder.add_meta_graph_and_variables()` step, which saves a
|
||||
# named-based checkpoint.
|
||||
# TODO(b/113134168): Add fn to Builder to save with object-based saver.
|
||||
# TODO(b/113178242): This should only export the model json structure. Only
|
||||
# one save is needed once the weights can be copied from the model to clone.
|
||||
checkpoint_path = _export_model_variables(model, path)
|
||||
|
||||
# Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that
|
||||
# Keras models and `Estimator`s are exported with the same format.
|
||||
# Every time a mode is exported, the code checks to see if new variables have
|
||||
# been created (e.g. optimizer slot variables). If that is the case, the
|
||||
# checkpoint is re-saved to include the new variables.
|
||||
export_args = {'builder': builder,
|
||||
'model': model,
|
||||
'custom_objects': custom_objects,
|
||||
'checkpoint_path': checkpoint_path,
|
||||
'input_signature': input_signature}
|
||||
|
||||
has_saved_vars = False
|
||||
if model.optimizer:
|
||||
if isinstance(model.optimizer, (optimizers.TFOptimizer,
|
||||
optimizer_v2.OptimizerV2)):
|
||||
_export_mode(mode_keys.ModeKeys.TRAIN, has_saved_vars, **export_args)
|
||||
has_saved_vars = True
|
||||
_export_mode(mode_keys.ModeKeys.TEST, has_saved_vars, **export_args)
|
||||
else:
|
||||
logging.warning(
|
||||
'Model was compiled with an optimizer, but the optimizer is not from '
|
||||
'`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving '
|
||||
'graph was exported. The train and evaluate graphs were not added to '
|
||||
'the SavedModel.')
|
||||
_export_mode(mode_keys.ModeKeys.PREDICT, has_saved_vars, **export_args)
|
||||
|
||||
builder.save(as_text)
|
||||
|
||||
|
||||
def _get_var_list(model):
|
||||
"""Returns list of all checkpointed saveable objects in the model."""
|
||||
var_list, _, _ = graph_view.ObjectGraphView(model).serialize_object_graph()
|
||||
return var_list
|
||||
|
||||
|
||||
def create_placeholder(spec):
|
||||
return K.placeholder(shape=spec.shape, dtype=spec.dtype, name=spec.name)
|
||||
|
||||
|
||||
def _export_mode(
|
||||
mode, has_saved_vars, builder, model, custom_objects, checkpoint_path,
|
||||
input_signature):
|
||||
"""Exports a model, and optionally saves new vars from the clone model.
|
||||
|
||||
Args:
|
||||
mode: A `tf.estimator.ModeKeys` string.
|
||||
has_saved_vars: A `boolean` indicating whether the SavedModel has already
|
||||
exported variables.
|
||||
builder: A `SavedModelBuilder` object.
|
||||
model: A `tf.keras.Model` object.
|
||||
custom_objects: A dictionary mapping string names to custom classes
|
||||
or functions.
|
||||
checkpoint_path: String path to checkpoint.
|
||||
input_signature: Nested TensorSpec containing the expected inputs. Can be
|
||||
`None`, in which case the signature will be inferred from the model.
|
||||
|
||||
Raises:
|
||||
ValueError: If the train/eval mode is being exported, but the model does
|
||||
not have an optimizer.
|
||||
"""
|
||||
compile_clone = (mode != mode_keys.ModeKeys.PREDICT)
|
||||
if compile_clone and not model.optimizer:
|
||||
raise ValueError(
|
||||
'Model does not have an optimizer. Cannot export mode %s' % mode)
|
||||
|
||||
model_graph = ops.get_default_graph()
|
||||
with ops.Graph().as_default() as g, K.learning_phase_scope(
|
||||
mode == mode_keys.ModeKeys.TRAIN):
|
||||
|
||||
if input_signature is None:
|
||||
input_tensors = None
|
||||
else:
|
||||
input_tensors = nest.map_structure(create_placeholder, input_signature)
|
||||
|
||||
# Clone the model into a blank graph. This will create placeholders for inputs
|
||||
# and targets.
|
||||
clone = models_lib.clone_and_build_model(
|
||||
model, input_tensors=input_tensors, custom_objects=custom_objects,
|
||||
compile_clone=compile_clone)
|
||||
|
||||
# Make sure that iterations variable is added to the global step collection,
|
||||
# to ensure that, when the SavedModel graph is loaded, the iterations
|
||||
# variable is returned by `tf.compat.v1.train.get_global_step()`. This is
|
||||
# required for compatibility with the SavedModelEstimator.
|
||||
if compile_clone:
|
||||
g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)
|
||||
|
||||
# Extract update and train ops from train/test/predict functions.
|
||||
train_op = None
|
||||
if mode == mode_keys.ModeKeys.TRAIN:
|
||||
clone._make_train_function() # pylint: disable=protected-access
|
||||
train_op = clone.train_function.updates_op
|
||||
elif mode == mode_keys.ModeKeys.TEST:
|
||||
clone._make_test_function() # pylint: disable=protected-access
|
||||
else:
|
||||
clone._make_predict_function() # pylint: disable=protected-access
|
||||
g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)
|
||||
|
||||
with session.Session().as_default():
|
||||
clone_var_list = _get_var_list(clone)
|
||||
if has_saved_vars:
|
||||
# Confirm all variables in the clone have an entry in the checkpoint.
|
||||
status = clone.load_weights(checkpoint_path)
|
||||
status.assert_existing_objects_matched()
|
||||
else:
|
||||
# Confirm that variables between the clone and model match up exactly,
|
||||
# not counting optimizer objects. Optimizer objects are ignored because
|
||||
# if the model has not trained, the slot variables will not have been
|
||||
# created yet.
|
||||
# TODO(b/113179535): Replace with trackable equivalence.
|
||||
_assert_same_non_optimizer_objects(model, model_graph, clone, g)
|
||||
|
||||
# TODO(b/113178242): Use value transfer for trackable objects.
|
||||
clone.load_weights(checkpoint_path)
|
||||
|
||||
# Add graph and variables to SavedModel.
|
||||
# TODO(b/113134168): Switch to add_meta_graph_and_variables.
|
||||
clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
|
||||
builder._has_saved_variables = True # pylint: disable=protected-access
|
||||
|
||||
# Add graph to the SavedModel builder.
|
||||
builder.add_meta_graph(
|
||||
model_utils.EXPORT_TAG_MAP[mode],
|
||||
signature_def_map=_create_signature_def_map(clone, mode),
|
||||
saver=saver_lib.Saver(
|
||||
clone_var_list,
|
||||
# Allow saving Models with no variables. This is somewhat odd, but
|
||||
# it's not necessarily a bug.
|
||||
allow_empty=True),
|
||||
init_op=variables.local_variables_initializer(),
|
||||
train_op=train_op)
|
||||
return None
|
||||
|
||||
|
||||
def _create_signature_def_map(model, mode):
|
||||
"""Creates a SignatureDef map from a Keras model."""
|
||||
inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)}
|
||||
if model.optimizer:
|
||||
targets_dict = {x.name.split(':')[0]: x
|
||||
for x in model._targets if x is not None} # pylint: disable=protected-access
|
||||
inputs_dict.update(targets_dict)
|
||||
outputs_dict = {name: x
|
||||
for name, x in zip(model.output_names, model.outputs)}
|
||||
metrics = saving_utils.extract_model_metrics(model)
|
||||
|
||||
# Add metric variables to the `LOCAL_VARIABLES` collection. Metric variables
|
||||
# are by default not added to any collections. We are doing this here, so
|
||||
# that metric variables get initialized.
|
||||
local_vars = set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))
|
||||
vars_to_add = set()
|
||||
if metrics is not None:
|
||||
for key, value in six.iteritems(metrics):
|
||||
if isinstance(value, metrics_lib.Metric):
|
||||
vars_to_add.update(value.variables)
|
||||
# Convert Metric instances to (value_tensor, update_op) tuple.
|
||||
metrics[key] = (value.result(), value.updates[0])
|
||||
# Remove variables that are in the local variables collection already.
|
||||
vars_to_add = vars_to_add.difference(local_vars)
|
||||
for v in vars_to_add:
|
||||
ops.add_to_collection(ops.GraphKeys.LOCAL_VARIABLES, v)
|
||||
|
||||
export_outputs = model_utils.export_outputs_for_mode(
|
||||
mode,
|
||||
predictions=outputs_dict,
|
||||
loss=model.total_loss if model.optimizer else None,
|
||||
metrics=metrics)
|
||||
return model_utils.build_all_signature_defs(
|
||||
inputs_dict,
|
||||
export_outputs=export_outputs,
|
||||
serving_only=(mode == mode_keys.ModeKeys.PREDICT))
|
||||
|
||||
|
||||
def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph): # pylint: disable=unused-argument
|
||||
"""Asserts model and clone contain the same trackable objects."""
|
||||
|
||||
# TODO(fchollet, kathywu): make sure this works in eager mode.
|
||||
return True
|
||||
|
||||
|
||||
@deprecation.deprecated(
|
||||
date=None,
|
||||
instructions=('The experimental save and load functions have been '
|
||||
'deprecated. Please switch to `tf.keras.models.load_model`.'))
|
||||
@keras_export('keras.experimental.load_from_saved_model')
|
||||
def load_from_saved_model(saved_model_path, custom_objects=None):
|
||||
"""Loads a keras Model from a SavedModel created by `export_saved_model()`.
|
||||
|
||||
This function reinstantiates model state by:
|
||||
1) loading model topology from json (this will eventually come
|
||||
from metagraph).
|
||||
2) loading model weights from checkpoint.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
|
||||
# Create a tf.keras model.
|
||||
model = tf.keras.Sequential()
|
||||
model.add(tf.keras.layers.Dense(1, input_shape=[10]))
|
||||
model.summary()
|
||||
|
||||
# Save the tf.keras model in the SavedModel format.
|
||||
path = '/tmp/simple_keras_model'
|
||||
tf.keras.experimental.export_saved_model(model, path)
|
||||
|
||||
# Load the saved keras model back.
|
||||
new_model = tf.keras.experimental.load_from_saved_model(path)
|
||||
new_model.summary()
|
||||
```
|
||||
|
||||
Args:
|
||||
saved_model_path: a string specifying the path to an existing SavedModel.
|
||||
custom_objects: Optional dictionary mapping names
|
||||
(strings) to custom classes or functions to be
|
||||
considered during deserialization.
|
||||
|
||||
Returns:
|
||||
a keras.Model instance.
|
||||
"""
|
||||
# restore model topology from json string
|
||||
model_json_filepath = os.path.join(
|
||||
compat.as_bytes(saved_model_path),
|
||||
compat.as_bytes(constants.ASSETS_DIRECTORY),
|
||||
compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
|
||||
model_json = file_io.read_file_to_string(model_json_filepath)
|
||||
model = model_from_json(model_json, custom_objects=custom_objects)
|
||||
|
||||
# restore model weights
|
||||
checkpoint_prefix = os.path.join(
|
||||
compat.as_text(saved_model_path),
|
||||
compat.as_text(constants.VARIABLES_DIRECTORY),
|
||||
compat.as_text(constants.VARIABLES_FILENAME))
|
||||
model.load_weights(checkpoint_prefix)
|
||||
return model
|
@ -28,28 +28,19 @@ from tensorflow.python import keras
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import regularizers
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import training as model_lib
|
||||
from tensorflow.python.keras.optimizer_v2 import adadelta
|
||||
from tensorflow.python.keras.saving import saved_model as keras_saved_model
|
||||
from tensorflow.python.keras.saving import saved_model_experimental as keras_saved_model
|
||||
from tensorflow.python.keras.utils import mode_keys
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import load as tf_load
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.saved_model import model_utils
|
||||
from tensorflow.python.saved_model import save as tf_save
|
||||
from tensorflow.python.training import training as training_module
|
||||
|
||||
|
||||
@ -552,252 +543,5 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllClose(ref_predict, predictions, atol=1e-05)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class TestModelSavingAndLoadingV2(keras_parameterized.TestCase):
|
||||
|
||||
def _save_model_dir(self, dirname='saved_model'):
|
||||
temp_dir = self.get_temp_dir()
|
||||
self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
|
||||
return os.path.join(temp_dir, dirname)
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_model_save_and_load(self):
|
||||
input_arr = np.random.random((1, 3)).astype(np.float32)
|
||||
target_arr = np.random.random((1, 4)).astype(np.float32)
|
||||
|
||||
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
|
||||
model.layers[-1].activity_regularizer = regularizers.get('l2')
|
||||
model.activity_regularizer = regularizers.get('l2')
|
||||
model.compile(
|
||||
loss='mse',
|
||||
optimizer='rmsprop')
|
||||
model.train_on_batch(input_arr, target_arr)
|
||||
|
||||
def callable_loss():
|
||||
return math_ops.reduce_sum(model.weights[0])
|
||||
model.add_loss(callable_loss)
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
|
||||
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
self.assertAllClose(self.evaluate(model.weights),
|
||||
self.evaluate(loaded.weights))
|
||||
|
||||
input_arr = constant_op.constant(
|
||||
np.random.random((1, 3)).astype(np.float32))
|
||||
self.assertAllClose(self.evaluate(model(input_arr)),
|
||||
self.evaluate(loaded(input_arr)))
|
||||
# Validate losses. The order of conditional losses may change between the
|
||||
# model and loaded model, so sort the losses first.
|
||||
if context.executing_eagerly():
|
||||
self.assertAllClose(sorted(self.evaluate(model.losses)),
|
||||
sorted(self.evaluate(loaded.losses)))
|
||||
else:
|
||||
self.assertAllClose(self.evaluate(model.get_losses_for(None)),
|
||||
self.evaluate(loaded.get_losses_for(None)))
|
||||
self.assertAllClose(
|
||||
sorted(self.evaluate(model.get_losses_for(input_arr))),
|
||||
sorted(self.evaluate(loaded.get_losses_for(input_arr))))
|
||||
|
||||
def test_trainable_weights(self):
|
||||
layer = keras.layers.Dense(4, name='custom_layer')
|
||||
layer.build([3,])
|
||||
layer.add_weight(
|
||||
'extra_weight', shape=[],
|
||||
initializer=init_ops.constant_initializer(11),
|
||||
trainable=True)
|
||||
layer.add_weight(
|
||||
'extra_weight_2', shape=[],
|
||||
initializer=init_ops.constant_initializer(12),
|
||||
trainable=False)
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
self.evaluate(variables.variables_initializer(layer.variables))
|
||||
tf_save.save(layer, saved_model_dir)
|
||||
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
|
||||
equal_attrs = ['name', '_expects_training_arg', 'trainable']
|
||||
for attr in equal_attrs:
|
||||
self.assertEqual(getattr(layer, attr), getattr(loaded, attr))
|
||||
|
||||
all_close = ['weights', 'trainable_weights', 'non_trainable_weights']
|
||||
for attr in all_close:
|
||||
self.assertAllClose(self.evaluate(getattr(layer, attr)),
|
||||
self.evaluate(getattr(loaded, attr)))
|
||||
|
||||
def test_maintains_losses(self):
|
||||
"""Tests that the layer losses do not change before and after export."""
|
||||
|
||||
class LayerWithLoss(keras.layers.Layer):
|
||||
|
||||
def call(self, inputs):
|
||||
self.add_loss(math_ops.reduce_sum(inputs), inputs)
|
||||
return inputs
|
||||
|
||||
model = keras.models.Sequential([LayerWithLoss()])
|
||||
model.compile(
|
||||
loss='mse',
|
||||
optimizer='rmsprop')
|
||||
input_arr = np.random.random((1, 3)).astype(np.float32)
|
||||
target_arr = np.random.random((1, 3)).astype(np.float32)
|
||||
|
||||
# Test that symbolic losses are maintained (train_on_batch saves symbolic
|
||||
# losses.)
|
||||
model.train_on_batch(input_arr, target_arr)
|
||||
previous_losses = model.losses[:]
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
self.assertAllEqual(previous_losses, model.losses)
|
||||
|
||||
if context.executing_eagerly():
|
||||
# Test that eager losses are maintained.
|
||||
model(input_arr) # Calls model eagerly, creating eager losses.
|
||||
previous_losses = model.losses[:]
|
||||
tf_save.save(model, saved_model_dir)
|
||||
self.assertAllEqual(previous_losses, model.losses)
|
||||
|
||||
def test_layer_with_learning_phase(self):
|
||||
layer = LayerWithLearningPhase()
|
||||
layer.build([None, None])
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(layer, saved_model_dir)
|
||||
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
|
||||
input_arr = array_ops.ones((4, 3))
|
||||
|
||||
# Run the layer, and use the keras backend learning phase
|
||||
keras.backend.set_learning_phase(0)
|
||||
self.assertAllEqual(input_arr, loaded(input_arr))
|
||||
keras.backend.set_learning_phase(1)
|
||||
self.assertAllEqual(array_ops.zeros((4, 3)), loaded(input_arr))
|
||||
|
||||
# Run the layer while explicitly setting the training argument
|
||||
self.assertAllEqual(
|
||||
input_arr, loaded(input_arr, training=constant_op.constant(False)))
|
||||
self.assertAllEqual(
|
||||
array_ops.zeros((4, 3)),
|
||||
loaded(input_arr, training=constant_op.constant(True)))
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_standard_loader(self):
|
||||
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
|
||||
model.activity_regularizer = regularizers.get('l2')
|
||||
def eager_loss():
|
||||
return math_ops.reduce_sum(model.weights[0])
|
||||
model.add_loss(eager_loss)
|
||||
|
||||
# Call predict to ensure that all layers are built and inputs are set.
|
||||
model.predict(np.random.random((1, 3)))
|
||||
saved_model_dir = self._save_model_dir()
|
||||
|
||||
tf_save.save(model, saved_model_dir)
|
||||
|
||||
loaded = tf_load.load(saved_model_dir)
|
||||
self.evaluate(variables.variables_initializer(loaded.variables))
|
||||
all_close = ['variables', 'trainable_variables',
|
||||
'non_trainable_variables']
|
||||
for attr in all_close:
|
||||
self.assertAllClose(self.evaluate(getattr(model, attr)),
|
||||
self.evaluate(getattr(loaded.keras_api, attr)))
|
||||
self.assertLen(loaded.regularization_losses, 1)
|
||||
expected_layers = len(model.layers)
|
||||
self.assertEqual(expected_layers, len(loaded.keras_api.layers))
|
||||
input_arr = array_ops.ones((4, 3))
|
||||
self.assertAllClose(self.evaluate(model(input_arr)),
|
||||
self.evaluate(loaded(input_arr)))
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
def test_compiled_model(self):
|
||||
input_arr = np.random.random((1, 3))
|
||||
target_arr = np.random.random((1, 4))
|
||||
|
||||
model = testing_utils.get_small_mlp(1, 4, input_dim=3)
|
||||
expected_predict = model.predict(input_arr)
|
||||
|
||||
# Compile and save model.
|
||||
model.compile('rmsprop', 'mse')
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(model, saved_model_dir)
|
||||
|
||||
# TODO(b/134519980): Issue with model.fit if the model call function uses
|
||||
# a tf.function (Graph mode only).
|
||||
with context.eager_mode():
|
||||
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
|
||||
actual_predict = loaded.predict(input_arr)
|
||||
self.assertAllClose(expected_predict, actual_predict)
|
||||
|
||||
loss_before = loaded.evaluate(input_arr, target_arr)
|
||||
loaded.fit(input_arr, target_arr)
|
||||
loss_after = loaded.evaluate(input_arr, target_arr)
|
||||
self.assertLess(loss_after, loss_before)
|
||||
predict = loaded.predict(input_arr)
|
||||
|
||||
ckpt_path = os.path.join(self.get_temp_dir(), 'weights')
|
||||
loaded.save_weights(ckpt_path)
|
||||
|
||||
# Ensure that the checkpoint is compatible with the original model.
|
||||
model.load_weights(ckpt_path)
|
||||
self.assertAllClose(predict, model.predict(input_arr))
|
||||
|
||||
def test_metadata_input_spec(self):
|
||||
class LayerWithNestedSpec(keras.layers.Layer):
|
||||
|
||||
def __init__(self):
|
||||
super(LayerWithNestedSpec, self).__init__()
|
||||
self.input_spec = {
|
||||
'a': keras.layers.InputSpec(max_ndim=3, axes={-1: 2}),
|
||||
'b': keras.layers.InputSpec(shape=(None, 2, 3), dtype='float16')}
|
||||
|
||||
layer = LayerWithNestedSpec()
|
||||
saved_model_dir = self._save_model_dir()
|
||||
tf_save.save(layer, saved_model_dir)
|
||||
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
|
||||
self.assertEqual(3, loaded.input_spec['a'].max_ndim)
|
||||
self.assertEqual({-1: 2}, loaded.input_spec['a'].axes)
|
||||
self.assertAllEqual([None, 2, 3], loaded.input_spec['b'].shape)
|
||||
self.assertEqual('float16', loaded.input_spec['b'].dtype)
|
||||
|
||||
def test_multi_input_model(self):
|
||||
input_1 = keras.layers.Input(shape=(3,))
|
||||
input_2 = keras.layers.Input(shape=(5,))
|
||||
model = keras.Model([input_1, input_2], [input_1, input_2])
|
||||
saved_model_dir = self._save_model_dir()
|
||||
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
|
||||
input_arr_1 = np.random.random((1, 3)).astype('float32')
|
||||
input_arr_2 = np.random.random((1, 5)).astype('float32')
|
||||
|
||||
outputs = loaded([input_arr_1, input_arr_2])
|
||||
self.assertAllEqual(input_arr_1, outputs[0])
|
||||
self.assertAllEqual(input_arr_2, outputs[1])
|
||||
|
||||
def test_revived_sequential(self):
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.Dense(5, input_shape=(3,),
|
||||
kernel_regularizer=regularizers.get('l2')))
|
||||
model.add(keras.layers.Dense(2, kernel_regularizer=regularizers.get('l2')))
|
||||
|
||||
self.evaluate(variables.variables_initializer(model.variables))
|
||||
|
||||
saved_model_dir = self._save_model_dir()
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded = keras_saved_model.load_from_saved_model_v2(saved_model_dir)
|
||||
|
||||
self.assertLen(loaded.layers, 2)
|
||||
self.assertLen(loaded.losses, 2)
|
||||
|
||||
loaded.pop()
|
||||
|
||||
self.assertLen(loaded.layers, 1)
|
||||
self.assertLen(loaded.losses, 1)
|
||||
|
||||
loaded.add(keras.layers.Dense(2, kernel_regularizer=regularizers.get('l2')))
|
||||
|
||||
self.assertLen(loaded.layers, 2)
|
||||
self.assertLen(loaded.losses, 2)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|