Warn users when saving SavedModel with metadata.

The metadata field will no longer be used by Keras. Since Keras is the only consumer metadata field, this field will be deprecated shortly.

PiperOrigin-RevId: 338353130
Change-Id: I762b7b223255966c78b5b362b0d07ec27351bb42
This commit is contained in:
Katherine Wu 2020-10-21 15:30:36 -07:00 committed by TensorFlower Gardener
parent 201500c78d
commit 0ed710fb76
4 changed files with 43 additions and 15 deletions

View File

@ -76,6 +76,9 @@ message SavedUserObject {
string identifier = 1; string identifier = 1;
// Version information from the producer of this SavedUserObject. // Version information from the producer of this SavedUserObject.
VersionDef version = 2; VersionDef version = 2;
// Deprecated! At the time of deprecation, Keras was the only user of this
// field, and its saving and loading code will be updated shortly.
// Please save your application-specific metadata to separate file
// Initialization-related metadata. // Initialization-related metadata.
string metadata = 3; string metadata = 3;
} }

View File

@ -26,7 +26,6 @@ from __future__ import print_function
import os import os
import shutil import shutil
import sys
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
@ -411,14 +410,16 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
self.evaluate(variables.variables_initializer(model.variables)) self.evaluate(variables.variables_initializer(model.variables))
saved_model_dir = self._save_model_dir() saved_model_dir = self._save_model_dir()
with self.captureWritesToStream(sys.stderr) as captured_logs: # TODO(kathywu): Re-enable this check after removing the tf.saved_model.save
model.save(saved_model_dir, save_format='tf') # metadata warning.
loaded = keras_load.load(saved_model_dir) # with self.captureWritesToStream(sys.stderr) as captured_logs:
model.save(saved_model_dir, save_format='tf')
loaded = keras_load.load(saved_model_dir)
# Assert that saving does not log deprecation warnings # Assert that saving does not log deprecation warnings
# (even if it needs to set learning phase for compat reasons) # (even if it needs to set learning phase for compat reasons)
if context.executing_eagerly(): # if context.executing_eagerly():
self.assertNotIn('deprecated', captured_logs.contents()) # self.assertNotIn('deprecated', captured_logs.contents())
input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32) input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32)
input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32) input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32)

View File

@ -349,6 +349,7 @@ py_strict_library(
"//tensorflow/python:framework", "//tensorflow/python:framework",
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:lib", "//tensorflow/python:lib",
"//tensorflow/python:platform",
"//tensorflow/python:resource_variable_ops", "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_util", "//tensorflow/python:tensor_util",
"//tensorflow/python:tf_export", "//tensorflow/python:tf_export",

View File

@ -43,6 +43,7 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import builder_impl from tensorflow.python.saved_model import builder_impl
from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import function_serialization from tensorflow.python.saved_model import function_serialization
@ -742,14 +743,17 @@ def _serialize_object_graph(saveable_view, asset_file_def_index):
if serialized is not None: if serialized is not None:
proto.concrete_functions[name].CopyFrom(serialized) proto.concrete_functions[name].CopyFrom(serialized)
saved_object_metadata = False
for obj, obj_proto in zip(saveable_view.nodes, proto.nodes): for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
_write_object_proto(obj, obj_proto, asset_file_def_index, has_saved_object_metadata = _write_object_proto(
saveable_view.function_name_map) obj, obj_proto, asset_file_def_index, saveable_view.function_name_map)
return proto saved_object_metadata = saved_object_metadata or has_saved_object_metadata
return proto, saved_object_metadata
def _write_object_proto(obj, proto, asset_file_def_index, function_name_map): def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
"""Saves an object into SavedObject proto.""" """Saves an object into SavedObject proto."""
has_saved_object_metadata = False # The metadata field will be deprecated.
if isinstance(obj, tracking.Asset): if isinstance(obj, tracking.Asset):
proto.asset.SetInParent() proto.asset.SetInParent()
proto.asset.asset_file_def_index = asset_file_def_index[obj] proto.asset.asset_file_def_index = asset_file_def_index[obj]
@ -785,11 +789,14 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
if registered_type_proto is None: if registered_type_proto is None:
# Fallback for types with no matching registration # Fallback for types with no matching registration
# pylint:disable=protected-access # pylint:disable=protected-access
metadata = obj._tracking_metadata
if metadata:
has_saved_object_metadata = True
registered_type_proto = saved_object_graph_pb2.SavedUserObject( registered_type_proto = saved_object_graph_pb2.SavedUserObject(
identifier=obj._object_identifier, identifier=obj._object_identifier,
version=versions_pb2.VersionDef( version=versions_pb2.VersionDef(
producer=1, min_consumer=1, bad_consumers=[]), producer=1, min_consumer=1, bad_consumers=[]),
metadata=obj._tracking_metadata) metadata=metadata)
# pylint:enable=protected-access # pylint:enable=protected-access
proto.user_object.CopyFrom(registered_type_proto) proto.user_object.CopyFrom(registered_type_proto)
@ -802,6 +809,7 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
# documentation. # documentation.
if hasattr(obj, "_write_object_proto"): if hasattr(obj, "_write_object_proto"):
obj._write_object_proto(proto, options) # pylint: disable=protected-access obj._write_object_proto(proto, options) # pylint: disable=protected-access
return has_saved_object_metadata
def _export_debug_info(exported_graph, export_dir): def _export_debug_info(exported_graph, export_dir):
@ -999,8 +1007,7 @@ def save(obj, export_dir, signatures=None, options=None):
instances with input signatures or concrete functions. Keys of such a instances with input signatures or concrete functions. Keys of such a
dictionary may be arbitrary strings, but will typically be from the dictionary may be arbitrary strings, but will typically be from the
`tf.saved_model.signature_constants` module. `tf.saved_model.signature_constants` module.
options: Optional, `tf.saved_model.SaveOptions` object that specifies options: `tf.saved_model.SaveOptions` object for configuring save options.
options for saving.
Raises: Raises:
ValueError: If `obj` is not trackable. ValueError: If `obj` is not trackable.
@ -1151,11 +1158,27 @@ def _build_meta_graph_impl(obj,
for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access
function_aliases[fdef.name] = alias function_aliases[fdef.name] = alias
object_graph_proto = _serialize_object_graph(saveable_view, object_graph_proto, saved_object_metadata = _serialize_object_graph(
asset_info.asset_index) saveable_view, asset_info.asset_index)
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto) meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
return meta_graph_def, exported_graph, object_saver, asset_info if saved_object_metadata:
tf_logging.warn(
'FOR KERAS USERS: The object that you are saving contains one or more '
'Keras models or layers. If you are loading the SavedModel with '
'`tf.keras.models.load_model`, continue reading (otherwise, you may '
'ignore the following instructions). Please change your code to save '
'with `tf.keras.models.save_model` or `model.save`, and confirm that '
'the file "keras.metadata" exists in the export directory. In the '
'future, Keras will only load the SavedModels that have this file. In '
'other words, `tf.saved_model.save` will no longer write SavedModels '
'that can be recovered as Keras models (this will apply in TF 2.5).'
'\n\nFOR DEVS: If you are overwriting _tracking_metadata in your class,'
' this property has been used to save metadata in the SavedModel. The '
'metadta field will be deprecated soon, so please move the metadata to '
'a different file.')
return (meta_graph_def, exported_graph, object_saver, asset_info)
def _build_meta_graph(obj, def _build_meta_graph(obj,