From 8366f2ecea4c3b3a4b2e0114af06992f5af5bc36 Mon Sep 17 00:00:00 2001
From: Katherine Wu
Date: Wed, 21 Oct 2020 16:03:39 -0700
Subject: [PATCH] Save Keras metadata in a separate file and raise deprecation
 warnings when loading a SavedModel saved with tf.saved_model.save().

PiperOrigin-RevId: 338359077
Change-Id: I93d8c345efb323cd8d4fd1fda4c8e5e86b37d620
---
 tensorflow/python/keras/saving/BUILD          |  1 +
 .../keras/saving/saved_model/constants.py     |  4 ++
 .../python/keras/saving/saved_model/load.py   | 25 +++++++--
 .../python/keras/saving/saved_model/save.py   | 40 +++++++++++++-
 tensorflow/python/saved_model/save.py         | 54 +++++++++++++----
 .../python/training/tracking/graph_view.py    | 10 +++-
 6 files changed, 117 insertions(+), 17 deletions(-)

diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD
index 51095c1c75f..7dcc9ae49a6 100644
--- a/tensorflow/python/keras/saving/BUILD
+++ b/tensorflow/python/keras/saving/BUILD
@@ -49,6 +49,7 @@ py_library(
     deps = [
         "//tensorflow/python:lib",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:platform",
         "//tensorflow/python:saver",
         "//tensorflow/python:tensor_spec",
         "//tensorflow/python/eager:def_function",
diff --git a/tensorflow/python/keras/saving/saved_model/constants.py b/tensorflow/python/keras/saving/saved_model/constants.py
index 3f1eca9c500..12265e0a3f3 100644
--- a/tensorflow/python/keras/saving/saved_model/constants.py
+++ b/tensorflow/python/keras/saving/saved_model/constants.py
@@ -26,3 +26,7 @@ KERAS_ATTR = 'keras_api'
 # Keys for the serialization cache.
 # Maps to the keras serialization dict {Layer --> SerializedAttributes object}
 KERAS_CACHE_KEY = 'keras_serialized_attributes'
+
+
+# Name of Keras metadata file stored in the SavedModel.
+SAVED_METADATA_PATH = 'keras_metadata.pb'
diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py
index cb6d340ea03..43c1d2bd0d4 100644
--- a/tensorflow/python/keras/saving/saved_model/load.py
+++ b/tensorflow/python/keras/saving/saved_model/load.py
@@ -17,9 +17,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import os
 import re
 import types
 
+from google.protobuf import message
+
 from tensorflow.core.framework import versions_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function as defun
@@ -38,6 +41,7 @@ from tensorflow.python.keras.saving.saved_model.serialized_attributes import Com
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import metrics_utils
 from tensorflow.python.keras.utils.generic_utils import LazyLoader
+from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import load as tf_load
 from tensorflow.python.saved_model import loader_impl
@@ -121,13 +125,26 @@ def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
 
   # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
   # TODO(kathywu): Add code to load from objects that contain all endpoints
-  # The Keras metadata file is not yet saved, so create it from the SavedModel.
+  # Look for metadata file or parse the SavedModel
   metadata = saved_metadata_pb2.SavedMetadata()
   meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
   object_graph_def = meta_graph_def.object_graph_def
-  # TODO(kathywu): When the keras metadata file is saved, load it directly
-  # instead of calling the _read_legacy_metadata function.
-  _read_legacy_metadata(object_graph_def, metadata)
+  path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
+  if gfile.Exists(path_to_metadata_pb):
+    try:
+      with gfile.GFile(path_to_metadata_pb, 'rb') as f:
+        file_content = f.read()
+      metadata.ParseFromString(file_content)
+    except message.DecodeError as e:
+      raise IOError('Cannot parse keras metadata {}: {}.'
+                    .format(path_to_metadata_pb, str(e)))
+  else:
+    logging.warning('SavedModel saved prior to TF 2.4 detected when loading '
+                    'Keras model. Please ensure that you are saving the model '
+                    'with model.save() or tf.keras.models.save_model(), *NOT* '
+                    'tf.saved_model.save(). To confirm, there should be a file '
+                    'named "keras_metadata.pb" in the SavedModel directory.')
+    _read_legacy_metadata(object_graph_def, metadata)
 
   if not metadata.nodes:
     # When there are no Keras objects, return the results from the core loader
diff --git a/tensorflow/python/keras/saving/saved_model/save.py b/tensorflow/python/keras/saving/saved_model/save.py
index 16984a2221b..2ab7ebb60b1 100644
--- a/tensorflow/python/keras/saving/saved_model/save.py
+++ b/tensorflow/python/keras/saving/saved_model/save.py
@@ -18,15 +18,21 @@ from __future__ import division
 from __future__ import print_function
 
 import os
+
+from tensorflow.core.framework import versions_pb2
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.protobuf import saved_metadata_pb2
 from tensorflow.python.keras.saving import saving_utils
+from tensorflow.python.keras.saving.saved_model import constants
 from tensorflow.python.keras.saving.saved_model import save_impl
 from tensorflow.python.keras.saving.saved_model import utils
 from tensorflow.python.keras.utils.generic_utils import LazyLoader
 from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
+from tensorflow.python.platform import gfile
 from tensorflow.python.saved_model import save as save_lib
 
+
 # To avoid circular dependencies between keras/engine and keras/saving,
 # code in keras/saving must delay imports.
 
@@ -86,7 +92,39 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
     # we use the default replica context here.
     with distribution_strategy_context._get_default_replica_context():  # pylint: disable=protected-access
       with utils.keras_option_scope(save_traces):
-        save_lib.save(model, filepath, signatures, options)
+        saved_nodes, node_paths = save_lib.save_and_return_nodes(
+            model, filepath, signatures, options)
+
+  # Save all metadata to a separate file in the SavedModel directory.
+  metadata = generate_keras_metadata(saved_nodes, node_paths)
+
+  with gfile.GFile(
+      os.path.join(filepath, constants.SAVED_METADATA_PATH), "wb") as w:
+    w.write(metadata.SerializeToString(deterministic=True))
 
   if not include_optimizer:
     model.optimizer = orig_optimizer
+
+
+def generate_keras_metadata(saved_nodes, node_paths):
+  """Constructs a KerasMetadata proto with the metadata of each keras object."""
+  metadata = saved_metadata_pb2.SavedMetadata()
+
+  for node_id, node in enumerate(saved_nodes):
+    if isinstance(node, base_layer.Layer):
+      path = node_paths[node]
+      if not path:
+        node_path = "root"
+      else:
+        node_path = "root.{}".format(
+            ".".join([ref.name for ref in path]))
+
+      metadata.nodes.add(
+          node_id=node_id,
+          node_path=node_path,
+          version=versions_pb2.VersionDef(
+              producer=1, min_consumer=1, bad_consumers=[]),
+          identifier=node._object_identifier,  # pylint: disable=protected-access
+          metadata=node._tracking_metadata)  # pylint: disable=protected-access
+
+  return metadata
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 27a2867dcd4..76af9885450 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -180,8 +180,9 @@ class _SaveableView(object):
     """
     self.options = options
     self.checkpoint_view = checkpoint_view
-    trackable_objects, node_ids, slot_variables = (
-        self.checkpoint_view.objects_ids_and_slot_variables())
+    trackable_objects, path_to_root, node_ids, slot_variables = (
+        self.checkpoint_view.objects_ids_and_slot_variables_and_paths())
+    self.node_paths = path_to_root
     self.nodes = trackable_objects
     self.node_ids = node_ids
     self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
@@ -1021,6 +1022,30 @@ def save(obj, export_dir, signatures=None, options=None):
   May not be called from within a function body.
   @end_compatibility
   """
+  save_and_return_nodes(obj, export_dir, signatures, options,
+                        raise_metadata_warning=True)
+
+
+def save_and_return_nodes(obj, export_dir, signatures=None, options=None,
+                          raise_metadata_warning=False):
+  """Saves a SavedModel while returning all saved nodes and their paths.
+
+  Please see `tf.saved_model.save` for details.
+
+  Args:
+    obj: A trackable object to export.
+    export_dir: A directory in which to write the SavedModel.
+    signatures: A function or dictionary of functions to save in the SavedModel
+      as signatures.
+    options: `tf.saved_model.SaveOptions` object for configuring save options.
+    raise_metadata_warning: Whether to raise the metadata warning. This arg will
+      be removed in TF 2.5.
+
+  Returns:
+    A tuple of (a list of saved nodes in the order they are serialized to the
+      `SavedObjectGraph`, dictionary mapping nodes to one possible path from
+      the root node to the key node)
+  """
   options = options or save_options.SaveOptions()
   # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
   # compatible (no sessions) and share it with this export API rather than
@@ -1028,8 +1053,9 @@ def save(obj, export_dir, signatures=None, options=None):
   saved_model = saved_model_pb2.SavedModel()
   meta_graph_def = saved_model.meta_graphs.add()
 
-  _, exported_graph, object_saver, asset_info = _build_meta_graph(
-      obj, signatures, options, meta_graph_def)
+  _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
+      _build_meta_graph(obj, signatures, options, meta_graph_def,
+                        raise_metadata_warning))
   saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
 
   # Write the checkpoint, copy assets into the assets directory, and write out
@@ -1069,6 +1095,8 @@ def save(obj, export_dir, signatures=None, options=None):
   # constants in the saved graph.
   ops.dismantle_graph(exported_graph)
 
+  return saved_nodes, node_paths
+
 
 def export_meta_graph(obj, filename, signatures=None, options=None):
   """Exports the MetaGraph proto of the `obj` to a file.
@@ -1095,7 +1123,7 @@ def export_meta_graph(obj, filename, signatures=None, options=None):
   """
   options = options or save_options.SaveOptions()
   export_dir = os.path.dirname(filename)
-  meta_graph_def, exported_graph, _, _ = _build_meta_graph(
+  meta_graph_def, exported_graph, _, _, _, _ = _build_meta_graph(
       obj, signatures, options)
 
   file_io.atomic_write_string_to_file(
@@ -1114,7 +1142,8 @@ def export_meta_graph(obj, filename, signatures=None, options=None):
 def _build_meta_graph_impl(obj,
                            signatures,
                            options,
-                           meta_graph_def=None):
+                           meta_graph_def=None,
+                           raise_metadata_warning=True):
   """Creates a MetaGraph containing the resources and functions of an object."""
   if ops.inside_function():
     raise AssertionError(
@@ -1162,7 +1191,7 @@ def _build_meta_graph_impl(obj,
                                           saveable_view, asset_info.asset_index)
   meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
 
-  if saved_object_metadata:
+  if saved_object_metadata and raise_metadata_warning:
     tf_logging.warn(
         'FOR KERAS USERS: The object that you are saving contains one or more '
         'Keras models or layers. If you are loading the SavedModel with '
@@ -1178,13 +1207,15 @@ def _build_meta_graph_impl(obj,
         'metadta field will be deprecated soon, so please move the metadata to '
         'a different file.')
 
-  return (meta_graph_def, exported_graph, object_saver, asset_info)
+  return (meta_graph_def, exported_graph, object_saver, asset_info,
+          saveable_view.nodes, saveable_view.node_paths)
 
 
 def _build_meta_graph(obj,
                       signatures,
                       options,
-                      meta_graph_def=None):
+                      meta_graph_def=None,
+                      raise_metadata_warning=True):
   """Creates a MetaGraph under a save context.
 
   Args:
@@ -1197,6 +1228,8 @@ def _build_meta_graph(obj,
     options: `tf.saved_model.SaveOptions` object that specifies options for
       saving.
     meta_graph_def: Optional, the MetaGraphDef proto fill.
+    raise_metadata_warning: Whether to raise a warning when user objects contain
+      non-empty metadata.
 
   Raises:
     AssertionError: If `export_meta_graph` is executing inside a `tf.function`.
@@ -1210,4 +1243,5 @@ def _build_meta_graph(obj,
   """
 
   with save_context.save_context(options):
-    return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
+    return _build_meta_graph_impl(obj, signatures, options, meta_graph_def,
+                                  raise_metadata_warning)
diff --git a/tensorflow/python/training/tracking/graph_view.py b/tensorflow/python/training/tracking/graph_view.py
index 6aeb41b47a9..61078ccf91d 100644
--- a/tensorflow/python/training/tracking/graph_view.py
+++ b/tensorflow/python/training/tracking/graph_view.py
@@ -430,7 +430,7 @@ class ObjectGraphView(object):
             name=base.OBJECT_GRAPH_PROTO_KEY))
     return named_saveable_objects
 
-  def objects_ids_and_slot_variables(self):
+  def objects_ids_and_slot_variables_and_paths(self):
     """Traverse the object graph and list all accessible objects.
 
     Looks for `Trackable` objects which are dependencies of
@@ -439,7 +439,8 @@ class ObjectGraphView(object):
     (i.e. if they would be saved with a checkpoint).
 
     Returns:
-      A tuple of (trackable objects, object -> node id, slot variables)
+      A tuple of (trackable objects, paths from root for each object,
+                  object -> node id, slot variables)
     """
     trackable_objects, path_to_root = self._breadth_first_traversal()
     object_names = object_identity.ObjectIdentityDictionary()
@@ -452,6 +453,11 @@ class ObjectGraphView(object):
         trackable_objects=trackable_objects,
         node_ids=node_ids,
         object_names=object_names)
+    return trackable_objects, path_to_root, node_ids, slot_variables
+
+  def objects_ids_and_slot_variables(self):
+    trackable_objects, _, node_ids, slot_variables = (
+        self.objects_ids_and_slot_variables_and_paths())
     return trackable_objects, node_ids, slot_variables
 
   def list_objects(self):
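
Usage sketch (reviewer note, not part of the patch): the snippet below is a
minimal illustration of the observable effect of this change, assuming a TF
build that includes it; the directory names are arbitrary. A model saved
through the Keras APIs now gets a keras_metadata.pb file written next to
saved_model.pb, while a directory produced by tf.saved_model.save() has no
such file, so tf.keras.models.load_model() falls back to
_read_legacy_metadata() and logs the new deprecation warning.

# Illustration only; assumes a TF build that includes this change.
import os
import tempfile

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

# Keras path: model.save() now also writes keras_metadata.pb into the
# SavedModel directory (via save_and_return_nodes + generate_keras_metadata).
keras_dir = os.path.join(tempfile.mkdtemp(), "keras_export")
model.save(keras_dir, save_format="tf")
print(os.path.exists(os.path.join(keras_dir, "keras_metadata.pb")))  # True

# Core path: tf.saved_model.save() does not write the metadata file, so loading
# the result with tf.keras.models.load_model() takes the legacy branch in
# load.py and logs the "SavedModel saved prior to TF 2.4 detected ..." warning.
core_dir = os.path.join(tempfile.mkdtemp(), "core_export")
tf.saved_model.save(model, core_dir)
print(os.path.exists(os.path.join(core_dir, "keras_metadata.pb")))  # False
restored = tf.keras.models.load_model(core_dir)

Either load path still returns a usable Keras model; the warning only signals
that the object-graph metadata field read by _read_legacy_metadata() is slated
for deprecation in favor of the new keras_metadata.pb file.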