From 5190387363d711c9eacef0d9e3f49e14bc8d7fde Mon Sep 17 00:00:00 2001
From: Kathy Wu <kathywu@google.com>
Date: Thu, 3 Dec 2020 16:55:21 -0800
Subject: [PATCH 1/2] Revert "Save Keras metadata in a separate proto and raise
 deprecation warnings when loading a SavedModel with tf.saved_model.save()."

This reverts commit 87fc5a0509d9990f4c2c5486521fc143fd3e8c1c.
---
 tensorflow/python/keras/saving/BUILD          |  1 -
 .../keras/saving/saved_model/constants.py     |  4 --
 .../python/keras/saving/saved_model/load.py   | 25 ++-------
 .../python/keras/saving/saved_model/save.py   | 40 +-------------
 tensorflow/python/saved_model/save.py         | 54 ++++---------------
 .../python/training/tracking/graph_view.py    | 10 +---
 6 files changed, 17 insertions(+), 117 deletions(-)

diff --git a/tensorflow/python/keras/saving/BUILD b/tensorflow/python/keras/saving/BUILD
index 7dcc9ae49a6..51095c1c75f 100644
--- a/tensorflow/python/keras/saving/BUILD
+++ b/tensorflow/python/keras/saving/BUILD
@@ -49,7 +49,6 @@ py_library(
     deps = [
         "//tensorflow/python:lib",
         "//tensorflow/python:math_ops",
-        "//tensorflow/python:platform",
         "//tensorflow/python:saver",
         "//tensorflow/python:tensor_spec",
         "//tensorflow/python/eager:def_function",
diff --git a/tensorflow/python/keras/saving/saved_model/constants.py b/tensorflow/python/keras/saving/saved_model/constants.py
index 12265e0a3f3..3f1eca9c500 100644
--- a/tensorflow/python/keras/saving/saved_model/constants.py
+++ b/tensorflow/python/keras/saving/saved_model/constants.py
@@ -26,7 +26,3 @@ KERAS_ATTR = 'keras_api'
 # Keys for the serialization cache.
 # Maps to the keras serialization dict {Layer --> SerializedAttributes object}
 KERAS_CACHE_KEY = 'keras_serialized_attributes'
-
-
-# Name of Keras metadata file stored in the SavedModel.
-SAVED_METADATA_PATH = 'keras_metadata.pb'
diff --git a/tensorflow/python/keras/saving/saved_model/load.py b/tensorflow/python/keras/saving/saved_model/load.py
index 43c1d2bd0d4..cb6d340ea03 100644
--- a/tensorflow/python/keras/saving/saved_model/load.py
+++ b/tensorflow/python/keras/saving/saved_model/load.py
@@ -17,12 +17,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import os
 import re
 import types
 
-from google.protobuf import message
-
 from tensorflow.core.framework import versions_pb2
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function as defun
@@ -41,7 +38,6 @@ from tensorflow.python.keras.saving.saved_model.serialized_attributes import Com
 from tensorflow.python.keras.utils import generic_utils
 from tensorflow.python.keras.utils import metrics_utils
 from tensorflow.python.keras.utils.generic_utils import LazyLoader
-from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.saved_model import load as tf_load
 from tensorflow.python.saved_model import loader_impl
@@ -125,26 +121,13 @@ def load(path, compile=True, options=None):  # pylint: disable=redefined-builtin
   # TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
   # TODO(kathywu): Add code to load from objects that contain all endpoints
 
-  # Look for metadata file or parse the SavedModel
+  # The Keras metadata file is not yet saved, so create it from the SavedModel.
   metadata = saved_metadata_pb2.SavedMetadata()
   meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
   object_graph_def = meta_graph_def.object_graph_def
-  path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
-  if gfile.Exists(path_to_metadata_pb):
-    try:
-      with gfile.GFile(path_to_metadata_pb, 'rb') as f:
-        file_content = f.read()
-      metadata.ParseFromString(file_content)
-    except message.DecodeError as e:
-      raise IOError('Cannot parse keras metadata {}: {}.'
-                    .format(path_to_metadata_pb, str(e)))
-  else:
-    logging.warning('SavedModel saved prior to TF 2.4 detected when loading '
-                    'Keras model. Please ensure that you are saving the model '
-                    'with model.save() or tf.keras.models.save_model(), *NOT* '
-                    'tf.saved_model.save(). To confirm, there should be a file '
-                    'named "keras_metadata.pb" in the SavedModel directory.')
-    _read_legacy_metadata(object_graph_def, metadata)
+  # TODO(kathywu): When the keras metadata file is saved, load it directly
+  # instead of calling the _read_legacy_metadata function.
+  _read_legacy_metadata(object_graph_def, metadata)
 
   if not metadata.nodes:
     # When there are no Keras objects, return the results from the core loader
diff --git a/tensorflow/python/keras/saving/saved_model/save.py b/tensorflow/python/keras/saving/saved_model/save.py
index 2ab7ebb60b1..16984a2221b 100644
--- a/tensorflow/python/keras/saving/saved_model/save.py
+++ b/tensorflow/python/keras/saving/saved_model/save.py
@@ -18,21 +18,15 @@ from __future__ import division
 from __future__ import print_function
 
 import os
-
-from tensorflow.core.framework import versions_pb2
 from tensorflow.python.distribute import distribution_strategy_context
 from tensorflow.python.keras import backend as K
-from tensorflow.python.keras.protobuf import saved_metadata_pb2
 from tensorflow.python.keras.saving import saving_utils
-from tensorflow.python.keras.saving.saved_model import constants
 from tensorflow.python.keras.saving.saved_model import save_impl
 from tensorflow.python.keras.saving.saved_model import utils
 from tensorflow.python.keras.utils.generic_utils import LazyLoader
 from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
-from tensorflow.python.platform import gfile
 from tensorflow.python.saved_model import save as save_lib
 
-
 # To avoid circular dependencies between keras/engine and keras/saving,
 # code in keras/saving must delay imports.
 
@@ -92,39 +86,7 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
     # we use the default replica context here.
     with distribution_strategy_context._get_default_replica_context():  # pylint: disable=protected-access
       with utils.keras_option_scope(save_traces):
-        saved_nodes, node_paths = save_lib.save_and_return_nodes(
-            model, filepath, signatures, options)
-
-    # Save all metadata to a separate file in the SavedModel directory.
-    metadata = generate_keras_metadata(saved_nodes, node_paths)
-
-  with gfile.GFile(
-      os.path.join(filepath, constants.SAVED_METADATA_PATH), "wb") as w:
-    w.write(metadata.SerializeToString(deterministic=True))
+        save_lib.save(model, filepath, signatures, options)
 
   if not include_optimizer:
     model.optimizer = orig_optimizer
-
-
-def generate_keras_metadata(saved_nodes, node_paths):
-  """Constructs a KerasMetadata proto with the metadata of each keras object."""
-  metadata = saved_metadata_pb2.SavedMetadata()
-
-  for node_id, node in enumerate(saved_nodes):
-    if isinstance(node, base_layer.Layer):
-      path = node_paths[node]
-      if not path:
-        node_path = "root"
-      else:
-        node_path = "root.{}".format(
-            ".".join([ref.name for ref in path]))
-
-      metadata.nodes.add(
-          node_id=node_id,
-          node_path=node_path,
-          version=versions_pb2.VersionDef(
-              producer=1, min_consumer=1, bad_consumers=[]),
-          identifier=node._object_identifier,  # pylint: disable=protected-access
-          metadata=node._tracking_metadata)  # pylint: disable=protected-access
-
-  return metadata
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 45d135d2e61..87a65724ab9 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -183,9 +183,8 @@ class _SaveableView(object):
     """
     self.options = options
     self.checkpoint_view = checkpoint_view
-    trackable_objects, path_to_root, node_ids, slot_variables = (
-        self.checkpoint_view.objects_ids_and_slot_variables_and_paths())
-    self.node_paths = path_to_root
+    trackable_objects, node_ids, slot_variables = (
+        self.checkpoint_view.objects_ids_and_slot_variables())
     self.nodes = trackable_objects
     self.node_ids = node_ids
     self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
@@ -1030,30 +1029,6 @@ def save(obj, export_dir, signatures=None, options=None):
   May not be called from within a function body.
   @end_compatibility
   """
-  save_and_return_nodes(obj, export_dir, signatures, options,
-                        raise_metadata_warning=True)
-
-
-def save_and_return_nodes(obj, export_dir, signatures=None, options=None,
-                          raise_metadata_warning=False):
-  """Saves a SavedModel while returning all saved nodes and their paths.
-
-  Please see `tf.saved_model.save` for details.
-
-  Args:
-    obj: A trackable object to export.
-    export_dir: A directory in which to write the SavedModel.
-    signatures: A function or dictionary of functions to save in the SavedModel
-      as signatures.
-    options: `tf.saved_model.SaveOptions` object for configuring save options.
-    raise_metadata_warning: Whether to raise the metadata warning. This arg will
-      be removed in TF 2.5.
-
-  Returns:
-    A tuple of (a list of saved nodes in the order they are serialized to the
-      `SavedObjectGraph`, dictionary mapping nodes to one possible path from
-      the root node to the key node)
-  """
   options = options or save_options.SaveOptions()
   # TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
   # compatible (no sessions) and share it with this export API rather than
@@ -1061,9 +1036,8 @@ def save_and_return_nodes(obj, export_dir, signatures=None, options=None,
   saved_model = saved_model_pb2.SavedModel()
   meta_graph_def = saved_model.meta_graphs.add()
 
-  _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
-      _build_meta_graph(obj, signatures, options, meta_graph_def,
-                        raise_metadata_warning))
+  _, exported_graph, object_saver, asset_info = _build_meta_graph(
+      obj, signatures, options, meta_graph_def)
   saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
 
   # Write the checkpoint, copy assets into the assets directory, and write out
@@ -1103,8 +1077,6 @@ def save_and_return_nodes(obj, export_dir, signatures=None, options=None,
   # constants in the saved graph.
   ops.dismantle_graph(exported_graph)
 
-  return saved_nodes, node_paths
-
 
 def export_meta_graph(obj, filename, signatures=None, options=None):
   """Exports the MetaGraph proto of the `obj` to a file.
@@ -1131,7 +1103,7 @@ def export_meta_graph(obj, filename, signatures=None, options=None):
   """
   options = options or save_options.SaveOptions()
   export_dir = os.path.dirname(filename)
-  meta_graph_def, exported_graph, _, _, _, _ = _build_meta_graph(
+  meta_graph_def, exported_graph, _, _ = _build_meta_graph(
       obj, signatures, options)
 
   file_io.atomic_write_string_to_file(
@@ -1150,8 +1122,7 @@ def export_meta_graph(obj, filename, signatures=None, options=None):
 def _build_meta_graph_impl(obj,
                            signatures,
                            options,
-                           meta_graph_def=None,
-                           raise_metadata_warning=True):
+                           meta_graph_def=None):
   """Creates a MetaGraph containing the resources and functions of an object."""
   if ops.inside_function():
     raise AssertionError(
@@ -1199,7 +1170,7 @@ def _build_meta_graph_impl(obj,
       saveable_view, asset_info.asset_index)
   meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
 
-  if saved_object_metadata and raise_metadata_warning:
+  if saved_object_metadata:
     tf_logging.warn(
         'FOR KERAS USERS: The object that you are saving contains one or more '
         'Keras models or layers. If you are loading the SavedModel with '
@@ -1215,15 +1186,13 @@ def _build_meta_graph_impl(obj,
         'metadta field will be deprecated soon, so please move the metadata to '
         'a different file.')
 
-  return (meta_graph_def, exported_graph, object_saver, asset_info,
-          saveable_view.nodes, saveable_view.node_paths)
+  return (meta_graph_def, exported_graph, object_saver, asset_info)
 
 
 def _build_meta_graph(obj,
                       signatures,
                       options,
-                      meta_graph_def=None,
-                      raise_metadata_warning=True):
+                      meta_graph_def=None):
   """Creates a MetaGraph under a save context.
 
   Args:
@@ -1236,8 +1205,6 @@ def _build_meta_graph(obj,
     options: `tf.saved_model.SaveOptions` object that specifies options for
       saving.
     meta_graph_def: Optional, the MetaGraphDef proto fill.
-    raise_metadata_warning: Whether to raise a warning when user objects contain
-      non-empty metadata.
 
   Raises:
     AssertionError: If `export_meta_graph` is executing inside a `tf.function`.
@@ -1251,5 +1218,4 @@ def _build_meta_graph(obj,
   """
 
   with save_context.save_context(options):
-    return _build_meta_graph_impl(obj, signatures, options, meta_graph_def,
-                                  raise_metadata_warning)
+    return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
diff --git a/tensorflow/python/training/tracking/graph_view.py b/tensorflow/python/training/tracking/graph_view.py
index 61078ccf91d..6aeb41b47a9 100644
--- a/tensorflow/python/training/tracking/graph_view.py
+++ b/tensorflow/python/training/tracking/graph_view.py
@@ -430,7 +430,7 @@ class ObjectGraphView(object):
               name=base.OBJECT_GRAPH_PROTO_KEY))
     return named_saveable_objects
 
-  def objects_ids_and_slot_variables_and_paths(self):
+  def objects_ids_and_slot_variables(self):
     """Traverse the object graph and list all accessible objects.
 
     Looks for `Trackable` objects which are dependencies of
@@ -439,8 +439,7 @@ class ObjectGraphView(object):
     (i.e. if they would be saved with a checkpoint).
 
     Returns:
-      A tuple of (trackable objects, paths from root for each object,
-                  object -> node id, slot variables)
+      A tuple of (trackable objects, object -> node id, slot variables)
     """
     trackable_objects, path_to_root = self._breadth_first_traversal()
     object_names = object_identity.ObjectIdentityDictionary()
@@ -453,11 +452,6 @@ class ObjectGraphView(object):
         trackable_objects=trackable_objects,
         node_ids=node_ids,
         object_names=object_names)
-    return trackable_objects, path_to_root, node_ids, slot_variables
-
-  def objects_ids_and_slot_variables(self):
-    trackable_objects, _, node_ids, slot_variables = (
-        self.objects_ids_and_slot_variables_and_paths())
     return trackable_objects, node_ids, slot_variables
 
   def list_objects(self):

From 45967fee602a4c30223b5e701526784dc0ad38d0 Mon Sep 17 00:00:00 2001
From: Kathy Wu <kathywu@google.com>
Date: Thu, 3 Dec 2020 16:55:31 -0800
Subject: [PATCH 2/2] Revert "Warn users when saving SavedModel with metadata."

This reverts commit 0ed710fb7625e3d099c7928905eefb027a8f65cb.
---
 .../core/protobuf/saved_object_graph.proto    |  3 --
 .../saving/saved_model/saved_model_test.py    | 13 +++---
 tensorflow/python/saved_model/BUILD           |  1 -
 tensorflow/python/saved_model/save.py         | 41 ++++---------------
 4 files changed, 15 insertions(+), 43 deletions(-)

diff --git a/tensorflow/core/protobuf/saved_object_graph.proto b/tensorflow/core/protobuf/saved_object_graph.proto
index 8df58683ead..a5b4cfbe823 100644
--- a/tensorflow/core/protobuf/saved_object_graph.proto
+++ b/tensorflow/core/protobuf/saved_object_graph.proto
@@ -76,9 +76,6 @@ message SavedUserObject {
   string identifier = 1;
   // Version information from the producer of this SavedUserObject.
   VersionDef version = 2;
-  // Deprecated! At the time of deprecation, Keras was the only user of this
-  // field, and its saving and loading code will be updated shortly.
-  // Please save your application-specific metadata to separate file
   // Initialization-related metadata.
   string metadata = 3;
 }
diff --git a/tensorflow/python/keras/saving/saved_model/saved_model_test.py b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
index 12a3a7761b8..726ef570da8 100644
--- a/tensorflow/python/keras/saving/saved_model/saved_model_test.py
+++ b/tensorflow/python/keras/saving/saved_model/saved_model_test.py
@@ -26,6 +26,7 @@ from __future__ import print_function
 
 import os
 import shutil
+import sys
 
 from absl.testing import parameterized
 import numpy as np
@@ -410,16 +411,14 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
     self.evaluate(variables.variables_initializer(model.variables))
     saved_model_dir = self._save_model_dir()
 
-    # TODO(kathywu): Re-enable this check after removing the tf.saved_model.save
-    # metadata warning.
-    # with self.captureWritesToStream(sys.stderr) as captured_logs:
-    model.save(saved_model_dir, save_format='tf')
-    loaded = keras_load.load(saved_model_dir)
+    with self.captureWritesToStream(sys.stderr) as captured_logs:
+      model.save(saved_model_dir, save_format='tf')
+      loaded = keras_load.load(saved_model_dir)
 
     # Assert that saving does not log deprecation warnings
     # (even if it needs to set learning phase for compat reasons)
-    # if context.executing_eagerly():
-    #   self.assertNotIn('deprecated', captured_logs.contents())
+    if context.executing_eagerly():
+      self.assertNotIn('deprecated', captured_logs.contents())
 
     input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32)
     input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32)
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index aac1062cebe..35431503964 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -349,7 +349,6 @@ py_strict_library(
         "//tensorflow/python:framework",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:lib",
-        "//tensorflow/python:platform",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:tensor_util",
         "//tensorflow/python:tf_export",
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 87a65724ab9..dbf7169d43e 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -43,7 +43,6 @@ from tensorflow.python.lib.io import file_io
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.platform import tf_logging
 from tensorflow.python.saved_model import builder_impl
 from tensorflow.python.saved_model import constants
 from tensorflow.python.saved_model import function_serialization
@@ -751,17 +750,14 @@ def _serialize_object_graph(saveable_view, asset_file_def_index):
     if serialized is not None:
       proto.concrete_functions[name].CopyFrom(serialized)
 
-  saved_object_metadata = False
   for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
-    has_saved_object_metadata = _write_object_proto(
-        obj, obj_proto, asset_file_def_index, saveable_view.function_name_map)
-    saved_object_metadata = saved_object_metadata or has_saved_object_metadata
-  return proto, saved_object_metadata
+    _write_object_proto(obj, obj_proto, asset_file_def_index,
+                        saveable_view.function_name_map)
+  return proto
 
 
 def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
   """Saves an object into SavedObject proto."""
-  has_saved_object_metadata = False  # The metadata field will be deprecated.
   if isinstance(obj, tracking.Asset):
     proto.asset.SetInParent()
     proto.asset.asset_file_def_index = asset_file_def_index[obj]
@@ -797,14 +793,11 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
     if registered_type_proto is None:
       # Fallback for types with no matching registration
       # pylint:disable=protected-access
-      metadata = obj._tracking_metadata
-      if metadata:
-        has_saved_object_metadata = True
       registered_type_proto = saved_object_graph_pb2.SavedUserObject(
           identifier=obj._object_identifier,
           version=versions_pb2.VersionDef(
               producer=1, min_consumer=1, bad_consumers=[]),
-          metadata=metadata)
+          metadata=obj._tracking_metadata)
       # pylint:enable=protected-access
     proto.user_object.CopyFrom(registered_type_proto)
 
@@ -817,7 +810,6 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
   # documentation.
   if hasattr(obj, "_write_object_proto"):
     obj._write_object_proto(proto, options)  # pylint: disable=protected-access
-  return has_saved_object_metadata
 
 
 def _export_debug_info(exported_graph, export_dir):
@@ -1015,7 +1007,8 @@ def save(obj, export_dir, signatures=None, options=None):
         instances with input signatures or concrete functions. Keys of such a
         dictionary may be arbitrary strings, but will typically be from the
         `tf.saved_model.signature_constants` module.
-    options: `tf.saved_model.SaveOptions` object for configuring save options.
+    options: Optional, `tf.saved_model.SaveOptions` object that specifies
+      options for saving.
 
   Raises:
     ValueError: If `obj` is not trackable.
@@ -1166,27 +1159,11 @@ def _build_meta_graph_impl(obj,
       for fdef in func._stateless_fn._function_cache.all_values():  # pylint: disable=protected-access
         function_aliases[fdef.name] = alias
 
-  object_graph_proto, saved_object_metadata = _serialize_object_graph(
-      saveable_view, asset_info.asset_index)
+  object_graph_proto = _serialize_object_graph(saveable_view,
+                                               asset_info.asset_index)
   meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
 
-  if saved_object_metadata:
-    tf_logging.warn(
-        'FOR KERAS USERS: The object that you are saving contains one or more '
-        'Keras models or layers. If you are loading the SavedModel with '
-        '`tf.keras.models.load_model`, continue reading (otherwise, you may '
-        'ignore the following instructions). Please change your code to save '
-        'with `tf.keras.models.save_model` or `model.save`, and confirm that '
-        'the file "keras.metadata" exists in the export directory. In the '
-        'future, Keras will only load the SavedModels that have this file. In '
-        'other words, `tf.saved_model.save` will no longer write SavedModels '
-        'that can be recovered as Keras models (this will apply in TF 2.5).'
-        '\n\nFOR DEVS: If you are overwriting _tracking_metadata in your class,'
-        ' this property has been used to save metadata in the SavedModel. The '
-        'metadta field will be deprecated soon, so please move the metadata to '
-        'a different file.')
-
-  return (meta_graph_def, exported_graph, object_saver, asset_info)
+  return meta_graph_def, exported_graph, object_saver, asset_info
 
 
 def _build_meta_graph(obj,