Warn users when saving SavedModel with metadata.

The metadata field will no longer be used by Keras. Since Keras is the only consumer metadata field, this field will be deprecated shortly.

PiperOrigin-RevId: 338353130
Change-Id: I762b7b223255966c78b5b362b0d07ec27351bb42
This commit is contained in:
Katherine Wu 2020-10-21 15:30:36 -07:00 committed by TensorFlower Gardener
parent 201500c78d
commit 0ed710fb76
4 changed files with 43 additions and 15 deletions
tensorflow
core/protobuf
python
keras/saving/saved_model
saved_model

View File

@ -76,6 +76,9 @@ message SavedUserObject {
string identifier = 1;
// Version information from the producer of this SavedUserObject.
VersionDef version = 2;
// Deprecated! At the time of deprecation, Keras was the only user of this
// field, and its saving and loading code will be updated shortly.
// Please save your application-specific metadata to separate file
// Initialization-related metadata.
string metadata = 3;
}

View File

@ -26,7 +26,6 @@ from __future__ import print_function
import os
import shutil
import sys
from absl.testing import parameterized
import numpy as np
@ -411,14 +410,16 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
self.evaluate(variables.variables_initializer(model.variables))
saved_model_dir = self._save_model_dir()
with self.captureWritesToStream(sys.stderr) as captured_logs:
model.save(saved_model_dir, save_format='tf')
loaded = keras_load.load(saved_model_dir)
# TODO(kathywu): Re-enable this check after removing the tf.saved_model.save
# metadata warning.
# with self.captureWritesToStream(sys.stderr) as captured_logs:
model.save(saved_model_dir, save_format='tf')
loaded = keras_load.load(saved_model_dir)
# Assert that saving does not log deprecation warnings
# (even if it needs to set learning phase for compat reasons)
if context.executing_eagerly():
self.assertNotIn('deprecated', captured_logs.contents())
# if context.executing_eagerly():
# self.assertNotIn('deprecated', captured_logs.contents())
input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32)
input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32)

View File

@ -349,6 +349,7 @@ py_strict_library(
"//tensorflow/python:framework",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
"//tensorflow/python:platform",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:tf_export",

View File

@ -43,6 +43,7 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import builder_impl
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import function_serialization
@ -742,14 +743,17 @@ def _serialize_object_graph(saveable_view, asset_file_def_index):
if serialized is not None:
proto.concrete_functions[name].CopyFrom(serialized)
saved_object_metadata = False
for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
_write_object_proto(obj, obj_proto, asset_file_def_index,
saveable_view.function_name_map)
return proto
has_saved_object_metadata = _write_object_proto(
obj, obj_proto, asset_file_def_index, saveable_view.function_name_map)
saved_object_metadata = saved_object_metadata or has_saved_object_metadata
return proto, saved_object_metadata
def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
"""Saves an object into SavedObject proto."""
has_saved_object_metadata = False # The metadata field will be deprecated.
if isinstance(obj, tracking.Asset):
proto.asset.SetInParent()
proto.asset.asset_file_def_index = asset_file_def_index[obj]
@ -785,11 +789,14 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
if registered_type_proto is None:
# Fallback for types with no matching registration
# pylint:disable=protected-access
metadata = obj._tracking_metadata
if metadata:
has_saved_object_metadata = True
registered_type_proto = saved_object_graph_pb2.SavedUserObject(
identifier=obj._object_identifier,
version=versions_pb2.VersionDef(
producer=1, min_consumer=1, bad_consumers=[]),
metadata=obj._tracking_metadata)
metadata=metadata)
# pylint:enable=protected-access
proto.user_object.CopyFrom(registered_type_proto)
@ -802,6 +809,7 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
# documentation.
if hasattr(obj, "_write_object_proto"):
obj._write_object_proto(proto, options) # pylint: disable=protected-access
return has_saved_object_metadata
def _export_debug_info(exported_graph, export_dir):
@ -999,8 +1007,7 @@ def save(obj, export_dir, signatures=None, options=None):
instances with input signatures or concrete functions. Keys of such a
dictionary may be arbitrary strings, but will typically be from the
`tf.saved_model.signature_constants` module.
options: Optional, `tf.saved_model.SaveOptions` object that specifies
options for saving.
options: `tf.saved_model.SaveOptions` object for configuring save options.
Raises:
ValueError: If `obj` is not trackable.
@ -1151,11 +1158,27 @@ def _build_meta_graph_impl(obj,
for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access
function_aliases[fdef.name] = alias
object_graph_proto = _serialize_object_graph(saveable_view,
asset_info.asset_index)
object_graph_proto, saved_object_metadata = _serialize_object_graph(
saveable_view, asset_info.asset_index)
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
return meta_graph_def, exported_graph, object_saver, asset_info
if saved_object_metadata:
tf_logging.warn(
'FOR KERAS USERS: The object that you are saving contains one or more '
'Keras models or layers. If you are loading the SavedModel with '
'`tf.keras.models.load_model`, continue reading (otherwise, you may '
'ignore the following instructions). Please change your code to save '
'with `tf.keras.models.save_model` or `model.save`, and confirm that '
'the file "keras.metadata" exists in the export directory. In the '
'future, Keras will only load the SavedModels that have this file. In '
'other words, `tf.saved_model.save` will no longer write SavedModels '
'that can be recovered as Keras models (this will apply in TF 2.5).'
'\n\nFOR DEVS: If you are overwriting _tracking_metadata in your class,'
' this property has been used to save metadata in the SavedModel. The '
'metadta field will be deprecated soon, so please move the metadata to '
'a different file.')
return (meta_graph_def, exported_graph, object_saver, asset_info)
def _build_meta_graph(obj,