Save Keras metadata in a separate folder and raise deprecation warnings when loading a SavedModel with tf.saved_model.save().

PiperOrigin-RevId: 338374188
Change-Id: I884ca90e9e3ed75e3b091dff6acb67c8db0d7e7b
This commit is contained in:
A. Unique TensorFlower 2020-10-21 17:32:34 -07:00 committed by TensorFlower Gardener
parent d410ac999b
commit ec8ef1a4f2
6 changed files with 17 additions and 117 deletions

View File

@ -49,7 +49,6 @@ py_library(
deps = [
"//tensorflow/python:lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
"//tensorflow/python:saver",
"//tensorflow/python:tensor_spec",
"//tensorflow/python/eager:def_function",

View File

@ -26,7 +26,3 @@ KERAS_ATTR = 'keras_api'
# Keys for the serialization cache.
# Maps to the keras serialization dict {Layer --> SerializedAttributes object}
KERAS_CACHE_KEY = 'keras_serialized_attributes'
# Name of Keras metadata file stored in the SavedModel.
SAVED_METADATA_PATH = 'keras_metadata.pb'

View File

@ -17,12 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
import types
from google.protobuf import message
from tensorflow.core.framework import versions_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import function as defun
@ -41,7 +38,6 @@ from tensorflow.python.keras.saving.saved_model.serialized_attributes import Com
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import load as tf_load
from tensorflow.python.saved_model import loader_impl
@ -125,26 +121,13 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
# TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
# TODO(kathywu): Add code to load from objects that contain all endpoints
# Look for metadata file or parse the SavedModel
# The Keras metadata file is not yet saved, so create it from the SavedModel.
metadata = saved_metadata_pb2.SavedMetadata()
meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
object_graph_def = meta_graph_def.object_graph_def
path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
if gfile.Exists(path_to_metadata_pb):
try:
with gfile.GFile(path_to_metadata_pb, 'rb') as f:
file_content = f.read()
metadata.ParseFromString(file_content)
except message.DecodeError as e:
raise IOError('Cannot parse keras metadata {}: {}.'
.format(path_to_metadata_pb, str(e)))
else:
logging.warning('SavedModel saved prior to TF 2.4 detected when loading '
'Keras model. Please ensure that you are saving the model '
'with model.save() or tf.keras.models.save_model(), *NOT* '
'tf.saved_model.save(). To confirm, there should be a file '
'named "keras_metadata.pb" in the SavedModel directory.')
_read_legacy_metadata(object_graph_def, metadata)
# TODO(kathywu): When the keras metadata file is saved, load it directly
# instead of calling the _read_legacy_metadata function.
_read_legacy_metadata(object_graph_def, metadata)
if not metadata.nodes:
# When there are no Keras objects, return the results from the core loader

View File

@ -18,21 +18,15 @@ from __future__ import division
from __future__ import print_function
import os
from tensorflow.core.framework import versions_pb2
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.protobuf import saved_metadata_pb2
from tensorflow.python.keras.saving import saving_utils
from tensorflow.python.keras.saving.saved_model import constants
from tensorflow.python.keras.saving.saved_model import save_impl
from tensorflow.python.keras.saving.saved_model import utils
from tensorflow.python.keras.utils.generic_utils import LazyLoader
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import save as save_lib
# To avoid circular dependencies between keras/engine and keras/saving,
# code in keras/saving must delay imports.
@ -92,39 +86,7 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
# we use the default replica context here.
with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access
with utils.keras_option_scope(save_traces):
saved_nodes, node_paths = save_lib.save_and_return_nodes(
model, filepath, signatures, options)
# Save all metadata to a separate file in the SavedModel directory.
metadata = generate_keras_metadata(saved_nodes, node_paths)
with gfile.GFile(
os.path.join(filepath, constants.SAVED_METADATA_PATH), "wb") as w:
w.write(metadata.SerializeToString(deterministic=True))
save_lib.save(model, filepath, signatures, options)
if not include_optimizer:
model.optimizer = orig_optimizer
def generate_keras_metadata(saved_nodes, node_paths):
"""Constructs a KerasMetadata proto with the metadata of each keras object."""
metadata = saved_metadata_pb2.SavedMetadata()
for node_id, node in enumerate(saved_nodes):
if isinstance(node, base_layer.Layer):
path = node_paths[node]
if not path:
node_path = "root"
else:
node_path = "root.{}".format(
".".join([ref.name for ref in path]))
metadata.nodes.add(
node_id=node_id,
node_path=node_path,
version=versions_pb2.VersionDef(
producer=1, min_consumer=1, bad_consumers=[]),
identifier=node._object_identifier, # pylint: disable=protected-access
metadata=node._tracking_metadata) # pylint: disable=protected-access
return metadata

View File

@ -180,9 +180,8 @@ class _SaveableView(object):
"""
self.options = options
self.checkpoint_view = checkpoint_view
trackable_objects, path_to_root, node_ids, slot_variables = (
self.checkpoint_view.objects_ids_and_slot_variables_and_paths())
self.node_paths = path_to_root
trackable_objects, node_ids, slot_variables = (
self.checkpoint_view.objects_ids_and_slot_variables())
self.nodes = trackable_objects
self.node_ids = node_ids
self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
@ -1022,30 +1021,6 @@ def save(obj, export_dir, signatures=None, options=None):
May not be called from within a function body.
@end_compatibility
"""
save_and_return_nodes(obj, export_dir, signatures, options,
raise_metadata_warning=True)
def save_and_return_nodes(obj, export_dir, signatures=None, options=None,
raise_metadata_warning=False):
"""Saves a SavedModel while returning all saved nodes and their paths.
Please see `tf.saved_model.save` for details.
Args:
obj: A trackable object to export.
export_dir: A directory in which to write the SavedModel.
signatures: A function or dictionary of functions to save in the SavedModel
as signatures.
options: `tf.saved_model.SaveOptions` object for configuring save options.
raise_metadata_warning: Whether to raise the metadata warning. This arg will
be removed in TF 2.5.
Returns:
A tuple of (a list of saved nodes in the order they are serialized to the
`SavedObjectGraph`, dictionary mapping nodes to one possible path from
the root node to the key node)
"""
options = options or save_options.SaveOptions()
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
# compatible (no sessions) and share it with this export API rather than
@ -1053,9 +1028,8 @@ def save_and_return_nodes(obj, export_dir, signatures=None, options=None,
saved_model = saved_model_pb2.SavedModel()
meta_graph_def = saved_model.meta_graphs.add()
_, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
_build_meta_graph(obj, signatures, options, meta_graph_def,
raise_metadata_warning))
_, exported_graph, object_saver, asset_info = _build_meta_graph(
obj, signatures, options, meta_graph_def)
saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
# Write the checkpoint, copy assets into the assets directory, and write out
@ -1095,8 +1069,6 @@ def save_and_return_nodes(obj, export_dir, signatures=None, options=None,
# constants in the saved graph.
ops.dismantle_graph(exported_graph)
return saved_nodes, node_paths
def export_meta_graph(obj, filename, signatures=None, options=None):
"""Exports the MetaGraph proto of the `obj` to a file.
@ -1123,7 +1095,7 @@ def export_meta_graph(obj, filename, signatures=None, options=None):
"""
options = options or save_options.SaveOptions()
export_dir = os.path.dirname(filename)
meta_graph_def, exported_graph, _, _, _, _ = _build_meta_graph(
meta_graph_def, exported_graph, _, _ = _build_meta_graph(
obj, signatures, options)
file_io.atomic_write_string_to_file(
@ -1142,8 +1114,7 @@ def export_meta_graph(obj, filename, signatures=None, options=None):
def _build_meta_graph_impl(obj,
signatures,
options,
meta_graph_def=None,
raise_metadata_warning=True):
meta_graph_def=None):
"""Creates a MetaGraph containing the resources and functions of an object."""
if ops.inside_function():
raise AssertionError(
@ -1191,7 +1162,7 @@ def _build_meta_graph_impl(obj,
saveable_view, asset_info.asset_index)
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
if saved_object_metadata and raise_metadata_warning:
if saved_object_metadata:
tf_logging.warn(
'FOR KERAS USERS: The object that you are saving contains one or more '
'Keras models or layers. If you are loading the SavedModel with '
@ -1207,15 +1178,13 @@ def _build_meta_graph_impl(obj,
'metadta field will be deprecated soon, so please move the metadata to '
'a different file.')
return (meta_graph_def, exported_graph, object_saver, asset_info,
saveable_view.nodes, saveable_view.node_paths)
return (meta_graph_def, exported_graph, object_saver, asset_info)
def _build_meta_graph(obj,
signatures,
options,
meta_graph_def=None,
raise_metadata_warning=True):
meta_graph_def=None):
"""Creates a MetaGraph under a save context.
Args:
@ -1228,8 +1197,6 @@ def _build_meta_graph(obj,
options: `tf.saved_model.SaveOptions` object that specifies options for
saving.
meta_graph_def: Optional, the MetaGraphDef proto fill.
raise_metadata_warning: Whether to raise a warning when user objects contain
non-empty metadata.
Raises:
AssertionError: If `export_meta_graph` is executing inside a `tf.function`.
@ -1243,5 +1210,4 @@ def _build_meta_graph(obj,
"""
with save_context.save_context(options):
return _build_meta_graph_impl(obj, signatures, options, meta_graph_def,
raise_metadata_warning)
return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)

View File

@ -430,7 +430,7 @@ class ObjectGraphView(object):
name=base.OBJECT_GRAPH_PROTO_KEY))
return named_saveable_objects
def objects_ids_and_slot_variables_and_paths(self):
def objects_ids_and_slot_variables(self):
"""Traverse the object graph and list all accessible objects.
Looks for `Trackable` objects which are dependencies of
@ -439,8 +439,7 @@ class ObjectGraphView(object):
(i.e. if they would be saved with a checkpoint).
Returns:
A tuple of (trackable objects, paths from root for each object,
object -> node id, slot variables)
A tuple of (trackable objects, object -> node id, slot variables)
"""
trackable_objects, path_to_root = self._breadth_first_traversal()
object_names = object_identity.ObjectIdentityDictionary()
@ -453,11 +452,6 @@ class ObjectGraphView(object):
trackable_objects=trackable_objects,
node_ids=node_ids,
object_names=object_names)
return trackable_objects, path_to_root, node_ids, slot_variables
def objects_ids_and_slot_variables(self):
trackable_objects, _, node_ids, slot_variables = (
self.objects_ids_and_slot_variables_and_paths())
return trackable_objects, node_ids, slot_variables
def list_objects(self):