From 5190387363d711c9eacef0d9e3f49e14bc8d7fde Mon Sep 17 00:00:00 2001 From: Kathy Wu <kathywu@google.com> Date: Thu, 3 Dec 2020 16:55:21 -0800 Subject: [PATCH 1/2] Revert "Save Keras metadata in a separate proto and raise deprecation warnings when loading a SavedModel with tf.saved_model.save()." This reverts commit 87fc5a0509d9990f4c2c5486521fc143fd3e8c1c. --- tensorflow/python/keras/saving/BUILD | 1 - .../keras/saving/saved_model/constants.py | 4 -- .../python/keras/saving/saved_model/load.py | 25 ++------- .../python/keras/saving/saved_model/save.py | 40 +------------- tensorflow/python/saved_model/save.py | 54 ++++--------------- .../python/training/tracking/graph_view.py | 10 +--- 6 files changed, 17 insertions(+), 117 deletions(-) diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD index 7dcc9ae49a6..51095c1c75f 100644 --- a/tensorflow/python/keras/saving/BUILD +++ b/tensorflow/python/keras/saving/BUILD @@ -49,7 +49,6 @@ py_library( deps = [ "//tensorflow/python:lib", "//tensorflow/python:math_ops", - "//tensorflow/python:platform", "//tensorflow/python:saver", "//tensorflow/python:tensor_spec", "//tensorflow/python/eager:def_function", diff --git a/tensorflow/python/keras/saving/saved_model/constants.py b/tensorflow/python/keras/saving/saved_model/constants.py index 12265e0a3f3..3f1eca9c500 100644 --- a/tensorflow/python/keras/saving/saved_model/constants.py +++ b/tensorflow/python/keras/saving/saved_model/constants.py @@ -26,7 +26,3 @@ KERAS_ATTR = 'keras_api' # Keys for the serialization cache. # Maps to the keras serialization dict {Layer --> SerializedAttributes object} KERAS_CACHE_KEY = 'keras_serialized_attributes' - - -# Name of Keras metadata file stored in the SavedModel. -SAVED_METADATA_PATH = 'keras_metadata.pb' diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py index 43c1d2bd0d4..cb6d340ea03 100644 --- a/tensorflow/python/keras/saving/saved_model/load.py +++ b/tensorflow/python/keras/saving/saved_model/load.py @@ -17,12 +17,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import re import types -from google.protobuf import message - from tensorflow.core.framework import versions_pb2 from tensorflow.python.eager import context from tensorflow.python.eager import function as defun @@ -41,7 +38,6 @@ from tensorflow.python.keras.saving.saved_model.serialized_attributes import Com from tensorflow.python.keras.utils import generic_utils from tensorflow.python.keras.utils import metrics_utils from tensorflow.python.keras.utils.generic_utils import LazyLoader -from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import load as tf_load from tensorflow.python.saved_model import loader_impl @@ -125,26 +121,13 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics. # TODO(kathywu): Add code to load from objects that contain all endpoints - # Look for metadata file or parse the SavedModel + # The Keras metadata file is not yet saved, so create it from the SavedModel. metadata = saved_metadata_pb2.SavedMetadata() meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0] object_graph_def = meta_graph_def.object_graph_def - path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH) - if gfile.Exists(path_to_metadata_pb): - try: - with gfile.GFile(path_to_metadata_pb, 'rb') as f: - file_content = f.read() - metadata.ParseFromString(file_content) - except message.DecodeError as e: - raise IOError('Cannot parse keras metadata {}: {}.' - .format(path_to_metadata_pb, str(e))) - else: - logging.warning('SavedModel saved prior to TF 2.4 detected when loading ' - 'Keras model. Please ensure that you are saving the model ' - 'with model.save() or tf.keras.models.save_model(), *NOT* ' - 'tf.saved_model.save(). To confirm, there should be a file ' - 'named "keras_metadata.pb" in the SavedModel directory.') - _read_legacy_metadata(object_graph_def, metadata) + # TODO(kathywu): When the keras metadata file is saved, load it directly + # instead of calling the _read_legacy_metadata function. + _read_legacy_metadata(object_graph_def, metadata) if not metadata.nodes: # When there are no Keras objects, return the results from the core loader diff --git a/tensorflow/python/keras/saving/saved_model/save.py b/tensorflow/python/keras/saving/saved_model/save.py index 2ab7ebb60b1..16984a2221b 100644 --- a/tensorflow/python/keras/saving/saved_model/save.py +++ b/tensorflow/python/keras/saving/saved_model/save.py @@ -18,21 +18,15 @@ from __future__ import division from __future__ import print_function import os - -from tensorflow.core.framework import versions_pb2 from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.keras import backend as K -from tensorflow.python.keras.protobuf import saved_metadata_pb2 from tensorflow.python.keras.saving import saving_utils -from tensorflow.python.keras.saving.saved_model import constants from tensorflow.python.keras.saving.saved_model import save_impl from tensorflow.python.keras.saving.saved_model import utils from tensorflow.python.keras.utils.generic_utils import LazyLoader from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite -from tensorflow.python.platform import gfile from tensorflow.python.saved_model import save as save_lib - # To avoid circular dependencies between keras/engine and keras/saving, # code in keras/saving must delay imports. @@ -92,39 +86,7 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None, # we use the default replica context here. with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access with utils.keras_option_scope(save_traces): - saved_nodes, node_paths = save_lib.save_and_return_nodes( - model, filepath, signatures, options) - - # Save all metadata to a separate file in the SavedModel directory. - metadata = generate_keras_metadata(saved_nodes, node_paths) - - with gfile.GFile( - os.path.join(filepath, constants.SAVED_METADATA_PATH), "wb") as w: - w.write(metadata.SerializeToString(deterministic=True)) + save_lib.save(model, filepath, signatures, options) if not include_optimizer: model.optimizer = orig_optimizer - - -def generate_keras_metadata(saved_nodes, node_paths): - """Constructs a KerasMetadata proto with the metadata of each keras object.""" - metadata = saved_metadata_pb2.SavedMetadata() - - for node_id, node in enumerate(saved_nodes): - if isinstance(node, base_layer.Layer): - path = node_paths[node] - if not path: - node_path = "root" - else: - node_path = "root.{}".format( - ".".join([ref.name for ref in path])) - - metadata.nodes.add( - node_id=node_id, - node_path=node_path, - version=versions_pb2.VersionDef( - producer=1, min_consumer=1, bad_consumers=[]), - identifier=node._object_identifier, # pylint: disable=protected-access - metadata=node._tracking_metadata) # pylint: disable=protected-access - - return metadata diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index 45d135d2e61..87a65724ab9 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -183,9 +183,8 @@ class _SaveableView(object): """ self.options = options self.checkpoint_view = checkpoint_view - trackable_objects, path_to_root, node_ids, slot_variables = ( - self.checkpoint_view.objects_ids_and_slot_variables_and_paths()) - self.node_paths = path_to_root + trackable_objects, node_ids, slot_variables = ( + self.checkpoint_view.objects_ids_and_slot_variables()) self.nodes = trackable_objects self.node_ids = node_ids self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary() @@ -1030,30 +1029,6 @@ def save(obj, export_dir, signatures=None, options=None): May not be called from within a function body. @end_compatibility """ - save_and_return_nodes(obj, export_dir, signatures, options, - raise_metadata_warning=True) - - -def save_and_return_nodes(obj, export_dir, signatures=None, options=None, - raise_metadata_warning=False): - """Saves a SavedModel while returning all saved nodes and their paths. - - Please see `tf.saved_model.save` for details. - - Args: - obj: A trackable object to export. - export_dir: A directory in which to write the SavedModel. - signatures: A function or dictionary of functions to save in the SavedModel - as signatures. - options: `tf.saved_model.SaveOptions` object for configuring save options. - raise_metadata_warning: Whether to raise the metadata warning. This arg will - be removed in TF 2.5. - - Returns: - A tuple of (a list of saved nodes in the order they are serialized to the - `SavedObjectGraph`, dictionary mapping nodes to one possible path from - the root node to the key node) - """ options = options or save_options.SaveOptions() # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x # compatible (no sessions) and share it with this export API rather than @@ -1061,9 +1036,8 @@ def save_and_return_nodes(obj, export_dir, signatures=None, options=None, saved_model = saved_model_pb2.SavedModel() meta_graph_def = saved_model.meta_graphs.add() - _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = ( - _build_meta_graph(obj, signatures, options, meta_graph_def, - raise_metadata_warning)) + _, exported_graph, object_saver, asset_info = _build_meta_graph( + obj, signatures, options, meta_graph_def) saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION # Write the checkpoint, copy assets into the assets directory, and write out @@ -1103,8 +1077,6 @@ def save_and_return_nodes(obj, export_dir, signatures=None, options=None, # constants in the saved graph. ops.dismantle_graph(exported_graph) - return saved_nodes, node_paths - def export_meta_graph(obj, filename, signatures=None, options=None): """Exports the MetaGraph proto of the `obj` to a file. @@ -1131,7 +1103,7 @@ def export_meta_graph(obj, filename, signatures=None, options=None): """ options = options or save_options.SaveOptions() export_dir = os.path.dirname(filename) - meta_graph_def, exported_graph, _, _, _, _ = _build_meta_graph( + meta_graph_def, exported_graph, _, _ = _build_meta_graph( obj, signatures, options) file_io.atomic_write_string_to_file( @@ -1150,8 +1122,7 @@ def export_meta_graph(obj, filename, signatures=None, options=None): def _build_meta_graph_impl(obj, signatures, options, - meta_graph_def=None, - raise_metadata_warning=True): + meta_graph_def=None): """Creates a MetaGraph containing the resources and functions of an object.""" if ops.inside_function(): raise AssertionError( @@ -1199,7 +1170,7 @@ def _build_meta_graph_impl(obj, saveable_view, asset_info.asset_index) meta_graph_def.object_graph_def.CopyFrom(object_graph_proto) - if saved_object_metadata and raise_metadata_warning: + if saved_object_metadata: tf_logging.warn( 'FOR KERAS USERS: The object that you are saving contains one or more ' 'Keras models or layers. If you are loading the SavedModel with ' @@ -1215,15 +1186,13 @@ def _build_meta_graph_impl(obj, 'metadta field will be deprecated soon, so please move the metadata to ' 'a different file.') - return (meta_graph_def, exported_graph, object_saver, asset_info, - saveable_view.nodes, saveable_view.node_paths) + return (meta_graph_def, exported_graph, object_saver, asset_info) def _build_meta_graph(obj, signatures, options, - meta_graph_def=None, - raise_metadata_warning=True): + meta_graph_def=None): """Creates a MetaGraph under a save context. Args: @@ -1236,8 +1205,6 @@ def _build_meta_graph(obj, options: `tf.saved_model.SaveOptions` object that specifies options for saving. meta_graph_def: Optional, the MetaGraphDef proto fill. - raise_metadata_warning: Whether to raise a warning when user objects contain - non-empty metadata. Raises: AssertionError: If `export_meta_graph` is executing inside a `tf.function`. @@ -1251,5 +1218,4 @@ def _build_meta_graph(obj, """ with save_context.save_context(options): - return _build_meta_graph_impl(obj, signatures, options, meta_graph_def, - raise_metadata_warning) + return _build_meta_graph_impl(obj, signatures, options, meta_graph_def) diff --git a/tensorflow/python/training/tracking/graph_view.py b/tensorflow/python/training/tracking/graph_view.py index 61078ccf91d..6aeb41b47a9 100644 --- a/tensorflow/python/training/tracking/graph_view.py +++ b/tensorflow/python/training/tracking/graph_view.py @@ -430,7 +430,7 @@ class ObjectGraphView(object): name=base.OBJECT_GRAPH_PROTO_KEY)) return named_saveable_objects - def objects_ids_and_slot_variables_and_paths(self): + def objects_ids_and_slot_variables(self): """Traverse the object graph and list all accessible objects. Looks for `Trackable` objects which are dependencies of @@ -439,8 +439,7 @@ class ObjectGraphView(object): (i.e. if they would be saved with a checkpoint). Returns: - A tuple of (trackable objects, paths from root for each object, - object -> node id, slot variables) + A tuple of (trackable objects, object -> node id, slot variables) """ trackable_objects, path_to_root = self._breadth_first_traversal() object_names = object_identity.ObjectIdentityDictionary() @@ -453,11 +452,6 @@ class ObjectGraphView(object): trackable_objects=trackable_objects, node_ids=node_ids, object_names=object_names) - return trackable_objects, path_to_root, node_ids, slot_variables - - def objects_ids_and_slot_variables(self): - trackable_objects, _, node_ids, slot_variables = ( - self.objects_ids_and_slot_variables_and_paths()) return trackable_objects, node_ids, slot_variables def list_objects(self): From 45967fee602a4c30223b5e701526784dc0ad38d0 Mon Sep 17 00:00:00 2001 From: Kathy Wu <kathywu@google.com> Date: Thu, 3 Dec 2020 16:55:31 -0800 Subject: [PATCH 2/2] Revert "Warn users when saving SavedModel with metadata." This reverts commit 0ed710fb7625e3d099c7928905eefb027a8f65cb. --- .../core/protobuf/saved_object_graph.proto | 3 -- .../saving/saved_model/saved_model_test.py | 13 +++--- tensorflow/python/saved_model/BUILD | 1 - tensorflow/python/saved_model/save.py | 41 ++++--------------- 4 files changed, 15 insertions(+), 43 deletions(-) diff --git a/tensorflow/core/protobuf/saved_object_graph.proto b/tensorflow/core/protobuf/saved_object_graph.proto index 8df58683ead..a5b4cfbe823 100644 --- a/tensorflow/core/protobuf/saved_object_graph.proto +++ b/tensorflow/core/protobuf/saved_object_graph.proto @@ -76,9 +76,6 @@ message SavedUserObject { string identifier = 1; // Version information from the producer of this SavedUserObject. VersionDef version = 2; - // Deprecated! At the time of deprecation, Keras was the only user of this - // field, and its saving and loading code will be updated shortly. - // Please save your application-specific metadata to separate file // Initialization-related metadata. string metadata = 3; } diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py index 12a3a7761b8..726ef570da8 100644 --- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py +++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py @@ -26,6 +26,7 @@ from __future__ import print_function import os import shutil +import sys from absl.testing import parameterized import numpy as np @@ -410,16 +411,14 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase): self.evaluate(variables.variables_initializer(model.variables)) saved_model_dir = self._save_model_dir() - # TODO(kathywu): Re-enable this check after removing the tf.saved_model.save - # metadata warning. - # with self.captureWritesToStream(sys.stderr) as captured_logs: - model.save(saved_model_dir, save_format='tf') - loaded = keras_load.load(saved_model_dir) + with self.captureWritesToStream(sys.stderr) as captured_logs: + model.save(saved_model_dir, save_format='tf') + loaded = keras_load.load(saved_model_dir) # Assert that saving does not log deprecation warnings # (even if it needs to set learning phase for compat reasons) - # if context.executing_eagerly(): - # self.assertNotIn('deprecated', captured_logs.contents()) + if context.executing_eagerly(): + self.assertNotIn('deprecated', captured_logs.contents()) input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32) input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32) diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index aac1062cebe..35431503964 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -349,7 +349,6 @@ py_strict_library( "//tensorflow/python:framework", "//tensorflow/python:framework_ops", "//tensorflow/python:lib", - "//tensorflow/python:platform", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:tensor_util", "//tensorflow/python:tf_export", diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index 87a65724ab9..dbf7169d43e 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -43,7 +43,6 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.platform import tf_logging from tensorflow.python.saved_model import builder_impl from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import function_serialization @@ -751,17 +750,14 @@ def _serialize_object_graph(saveable_view, asset_file_def_index): if serialized is not None: proto.concrete_functions[name].CopyFrom(serialized) - saved_object_metadata = False for obj, obj_proto in zip(saveable_view.nodes, proto.nodes): - has_saved_object_metadata = _write_object_proto( - obj, obj_proto, asset_file_def_index, saveable_view.function_name_map) - saved_object_metadata = saved_object_metadata or has_saved_object_metadata - return proto, saved_object_metadata + _write_object_proto(obj, obj_proto, asset_file_def_index, + saveable_view.function_name_map) + return proto def _write_object_proto(obj, proto, asset_file_def_index, function_name_map): """Saves an object into SavedObject proto.""" - has_saved_object_metadata = False # The metadata field will be deprecated. if isinstance(obj, tracking.Asset): proto.asset.SetInParent() proto.asset.asset_file_def_index = asset_file_def_index[obj] @@ -797,14 +793,11 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map): if registered_type_proto is None: # Fallback for types with no matching registration # pylint:disable=protected-access - metadata = obj._tracking_metadata - if metadata: - has_saved_object_metadata = True registered_type_proto = saved_object_graph_pb2.SavedUserObject( identifier=obj._object_identifier, version=versions_pb2.VersionDef( producer=1, min_consumer=1, bad_consumers=[]), - metadata=metadata) + metadata=obj._tracking_metadata) # pylint:enable=protected-access proto.user_object.CopyFrom(registered_type_proto) @@ -817,7 +810,6 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map): # documentation. if hasattr(obj, "_write_object_proto"): obj._write_object_proto(proto, options) # pylint: disable=protected-access - return has_saved_object_metadata def _export_debug_info(exported_graph, export_dir): @@ -1015,7 +1007,8 @@ def save(obj, export_dir, signatures=None, options=None): instances with input signatures or concrete functions. Keys of such a dictionary may be arbitrary strings, but will typically be from the `tf.saved_model.signature_constants` module. - options: `tf.saved_model.SaveOptions` object for configuring save options. + options: Optional, `tf.saved_model.SaveOptions` object that specifies + options for saving. Raises: ValueError: If `obj` is not trackable. @@ -1166,27 +1159,11 @@ def _build_meta_graph_impl(obj, for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access function_aliases[fdef.name] = alias - object_graph_proto, saved_object_metadata = _serialize_object_graph( - saveable_view, asset_info.asset_index) + object_graph_proto = _serialize_object_graph(saveable_view, + asset_info.asset_index) meta_graph_def.object_graph_def.CopyFrom(object_graph_proto) - if saved_object_metadata: - tf_logging.warn( - 'FOR KERAS USERS: The object that you are saving contains one or more ' - 'Keras models or layers. If you are loading the SavedModel with ' - '`tf.keras.models.load_model`, continue reading (otherwise, you may ' - 'ignore the following instructions). Please change your code to save ' - 'with `tf.keras.models.save_model` or `model.save`, and confirm that ' - 'the file "keras.metadata" exists in the export directory. In the ' - 'future, Keras will only load the SavedModels that have this file. In ' - 'other words, `tf.saved_model.save` will no longer write SavedModels ' - 'that can be recovered as Keras models (this will apply in TF 2.5).' - '\n\nFOR DEVS: If you are overwriting _tracking_metadata in your class,' - ' this property has been used to save metadata in the SavedModel. The ' - 'metadta field will be deprecated soon, so please move the metadata to ' - 'a different file.') - - return (meta_graph_def, exported_graph, object_saver, asset_info) + return meta_graph_def, exported_graph, object_saver, asset_info def _build_meta_graph(obj,