Warn users when saving SavedModel with metadata.
The metadata field will no longer be used by Keras. Since Keras is the only consumer metadata field, this field will be deprecated shortly. PiperOrigin-RevId: 338353130 Change-Id: I762b7b223255966c78b5b362b0d07ec27351bb42
This commit is contained in:
parent
201500c78d
commit
0ed710fb76
tensorflow
core/protobuf
python
@ -76,6 +76,9 @@ message SavedUserObject {
|
|||||||
string identifier = 1;
|
string identifier = 1;
|
||||||
// Version information from the producer of this SavedUserObject.
|
// Version information from the producer of this SavedUserObject.
|
||||||
VersionDef version = 2;
|
VersionDef version = 2;
|
||||||
|
// Deprecated! At the time of deprecation, Keras was the only user of this
|
||||||
|
// field, and its saving and loading code will be updated shortly.
|
||||||
|
// Please save your application-specific metadata to separate file
|
||||||
// Initialization-related metadata.
|
// Initialization-related metadata.
|
||||||
string metadata = 3;
|
string metadata = 3;
|
||||||
}
|
}
|
||||||
|
@ -26,7 +26,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -411,14 +410,16 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
|||||||
self.evaluate(variables.variables_initializer(model.variables))
|
self.evaluate(variables.variables_initializer(model.variables))
|
||||||
saved_model_dir = self._save_model_dir()
|
saved_model_dir = self._save_model_dir()
|
||||||
|
|
||||||
with self.captureWritesToStream(sys.stderr) as captured_logs:
|
# TODO(kathywu): Re-enable this check after removing the tf.saved_model.save
|
||||||
model.save(saved_model_dir, save_format='tf')
|
# metadata warning.
|
||||||
loaded = keras_load.load(saved_model_dir)
|
# with self.captureWritesToStream(sys.stderr) as captured_logs:
|
||||||
|
model.save(saved_model_dir, save_format='tf')
|
||||||
|
loaded = keras_load.load(saved_model_dir)
|
||||||
|
|
||||||
# Assert that saving does not log deprecation warnings
|
# Assert that saving does not log deprecation warnings
|
||||||
# (even if it needs to set learning phase for compat reasons)
|
# (even if it needs to set learning phase for compat reasons)
|
||||||
if context.executing_eagerly():
|
# if context.executing_eagerly():
|
||||||
self.assertNotIn('deprecated', captured_logs.contents())
|
# self.assertNotIn('deprecated', captured_logs.contents())
|
||||||
|
|
||||||
input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32)
|
input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32)
|
||||||
input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32)
|
input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32)
|
||||||
|
@ -349,6 +349,7 @@ py_strict_library(
|
|||||||
"//tensorflow/python:framework",
|
"//tensorflow/python:framework",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
"//tensorflow/python:lib",
|
"//tensorflow/python:lib",
|
||||||
|
"//tensorflow/python:platform",
|
||||||
"//tensorflow/python:resource_variable_ops",
|
"//tensorflow/python:resource_variable_ops",
|
||||||
"//tensorflow/python:tensor_util",
|
"//tensorflow/python:tensor_util",
|
||||||
"//tensorflow/python:tf_export",
|
"//tensorflow/python:tf_export",
|
||||||
|
@ -43,6 +43,7 @@ from tensorflow.python.lib.io import file_io
|
|||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
|
from tensorflow.python.platform import tf_logging
|
||||||
from tensorflow.python.saved_model import builder_impl
|
from tensorflow.python.saved_model import builder_impl
|
||||||
from tensorflow.python.saved_model import constants
|
from tensorflow.python.saved_model import constants
|
||||||
from tensorflow.python.saved_model import function_serialization
|
from tensorflow.python.saved_model import function_serialization
|
||||||
@ -742,14 +743,17 @@ def _serialize_object_graph(saveable_view, asset_file_def_index):
|
|||||||
if serialized is not None:
|
if serialized is not None:
|
||||||
proto.concrete_functions[name].CopyFrom(serialized)
|
proto.concrete_functions[name].CopyFrom(serialized)
|
||||||
|
|
||||||
|
saved_object_metadata = False
|
||||||
for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
|
for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
|
||||||
_write_object_proto(obj, obj_proto, asset_file_def_index,
|
has_saved_object_metadata = _write_object_proto(
|
||||||
saveable_view.function_name_map)
|
obj, obj_proto, asset_file_def_index, saveable_view.function_name_map)
|
||||||
return proto
|
saved_object_metadata = saved_object_metadata or has_saved_object_metadata
|
||||||
|
return proto, saved_object_metadata
|
||||||
|
|
||||||
|
|
||||||
def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
|
def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
|
||||||
"""Saves an object into SavedObject proto."""
|
"""Saves an object into SavedObject proto."""
|
||||||
|
has_saved_object_metadata = False # The metadata field will be deprecated.
|
||||||
if isinstance(obj, tracking.Asset):
|
if isinstance(obj, tracking.Asset):
|
||||||
proto.asset.SetInParent()
|
proto.asset.SetInParent()
|
||||||
proto.asset.asset_file_def_index = asset_file_def_index[obj]
|
proto.asset.asset_file_def_index = asset_file_def_index[obj]
|
||||||
@ -785,11 +789,14 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
|
|||||||
if registered_type_proto is None:
|
if registered_type_proto is None:
|
||||||
# Fallback for types with no matching registration
|
# Fallback for types with no matching registration
|
||||||
# pylint:disable=protected-access
|
# pylint:disable=protected-access
|
||||||
|
metadata = obj._tracking_metadata
|
||||||
|
if metadata:
|
||||||
|
has_saved_object_metadata = True
|
||||||
registered_type_proto = saved_object_graph_pb2.SavedUserObject(
|
registered_type_proto = saved_object_graph_pb2.SavedUserObject(
|
||||||
identifier=obj._object_identifier,
|
identifier=obj._object_identifier,
|
||||||
version=versions_pb2.VersionDef(
|
version=versions_pb2.VersionDef(
|
||||||
producer=1, min_consumer=1, bad_consumers=[]),
|
producer=1, min_consumer=1, bad_consumers=[]),
|
||||||
metadata=obj._tracking_metadata)
|
metadata=metadata)
|
||||||
# pylint:enable=protected-access
|
# pylint:enable=protected-access
|
||||||
proto.user_object.CopyFrom(registered_type_proto)
|
proto.user_object.CopyFrom(registered_type_proto)
|
||||||
|
|
||||||
@ -802,6 +809,7 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
|
|||||||
# documentation.
|
# documentation.
|
||||||
if hasattr(obj, "_write_object_proto"):
|
if hasattr(obj, "_write_object_proto"):
|
||||||
obj._write_object_proto(proto, options) # pylint: disable=protected-access
|
obj._write_object_proto(proto, options) # pylint: disable=protected-access
|
||||||
|
return has_saved_object_metadata
|
||||||
|
|
||||||
|
|
||||||
def _export_debug_info(exported_graph, export_dir):
|
def _export_debug_info(exported_graph, export_dir):
|
||||||
@ -999,8 +1007,7 @@ def save(obj, export_dir, signatures=None, options=None):
|
|||||||
instances with input signatures or concrete functions. Keys of such a
|
instances with input signatures or concrete functions. Keys of such a
|
||||||
dictionary may be arbitrary strings, but will typically be from the
|
dictionary may be arbitrary strings, but will typically be from the
|
||||||
`tf.saved_model.signature_constants` module.
|
`tf.saved_model.signature_constants` module.
|
||||||
options: Optional, `tf.saved_model.SaveOptions` object that specifies
|
options: `tf.saved_model.SaveOptions` object for configuring save options.
|
||||||
options for saving.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If `obj` is not trackable.
|
ValueError: If `obj` is not trackable.
|
||||||
@ -1151,11 +1158,27 @@ def _build_meta_graph_impl(obj,
|
|||||||
for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access
|
for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access
|
||||||
function_aliases[fdef.name] = alias
|
function_aliases[fdef.name] = alias
|
||||||
|
|
||||||
object_graph_proto = _serialize_object_graph(saveable_view,
|
object_graph_proto, saved_object_metadata = _serialize_object_graph(
|
||||||
asset_info.asset_index)
|
saveable_view, asset_info.asset_index)
|
||||||
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
|
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
|
||||||
|
|
||||||
return meta_graph_def, exported_graph, object_saver, asset_info
|
if saved_object_metadata:
|
||||||
|
tf_logging.warn(
|
||||||
|
'FOR KERAS USERS: The object that you are saving contains one or more '
|
||||||
|
'Keras models or layers. If you are loading the SavedModel with '
|
||||||
|
'`tf.keras.models.load_model`, continue reading (otherwise, you may '
|
||||||
|
'ignore the following instructions). Please change your code to save '
|
||||||
|
'with `tf.keras.models.save_model` or `model.save`, and confirm that '
|
||||||
|
'the file "keras.metadata" exists in the export directory. In the '
|
||||||
|
'future, Keras will only load the SavedModels that have this file. In '
|
||||||
|
'other words, `tf.saved_model.save` will no longer write SavedModels '
|
||||||
|
'that can be recovered as Keras models (this will apply in TF 2.5).'
|
||||||
|
'\n\nFOR DEVS: If you are overwriting _tracking_metadata in your class,'
|
||||||
|
' this property has been used to save metadata in the SavedModel. The '
|
||||||
|
'metadta field will be deprecated soon, so please move the metadata to '
|
||||||
|
'a different file.')
|
||||||
|
|
||||||
|
return (meta_graph_def, exported_graph, object_saver, asset_info)
|
||||||
|
|
||||||
|
|
||||||
def _build_meta_graph(obj,
|
def _build_meta_graph(obj,
|
||||||
|
Loading…
Reference in New Issue
Block a user