Merge pull request #45390 from k-w-w/r2.4
Revert recent SavedModel changes
This commit is contained in:
commit
97c3fef64b
@ -76,9 +76,6 @@ message SavedUserObject {
|
||||
string identifier = 1;
|
||||
// Version information from the producer of this SavedUserObject.
|
||||
VersionDef version = 2;
|
||||
// Deprecated! At the time of deprecation, Keras was the only user of this
|
||||
// field, and its saving and loading code will be updated shortly.
|
||||
// Please save your application-specific metadata to separate file
|
||||
// Initialization-related metadata.
|
||||
string metadata = 3;
|
||||
}
|
||||
|
@ -49,7 +49,6 @@ py_library(
|
||||
deps = [
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:saver",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
|
@ -26,7 +26,3 @@ KERAS_ATTR = 'keras_api'
|
||||
# Keys for the serialization cache.
|
||||
# Maps to the keras serialization dict {Layer --> SerializedAttributes object}
|
||||
KERAS_CACHE_KEY = 'keras_serialized_attributes'
|
||||
|
||||
|
||||
# Name of Keras metadata file stored in the SavedModel.
|
||||
SAVED_METADATA_PATH = 'keras_metadata.pb'
|
||||
|
@ -17,12 +17,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
import types
|
||||
|
||||
from google.protobuf import message
|
||||
|
||||
from tensorflow.core.framework import versions_pb2
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import function as defun
|
||||
@ -41,7 +38,6 @@ from tensorflow.python.keras.saving.saved_model.serialized_attributes import Com
|
||||
from tensorflow.python.keras.utils import generic_utils
|
||||
from tensorflow.python.keras.utils import metrics_utils
|
||||
from tensorflow.python.keras.utils.generic_utils import LazyLoader
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import load as tf_load
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
@ -125,26 +121,13 @@ def load(path, compile=True, options=None): # pylint: disable=redefined-builtin
|
||||
# TODO(kathywu): Add saving/loading of optimizer, compiled losses and metrics.
|
||||
# TODO(kathywu): Add code to load from objects that contain all endpoints
|
||||
|
||||
# Look for metadata file or parse the SavedModel
|
||||
# The Keras metadata file is not yet saved, so create it from the SavedModel.
|
||||
metadata = saved_metadata_pb2.SavedMetadata()
|
||||
meta_graph_def = loader_impl.parse_saved_model(path).meta_graphs[0]
|
||||
object_graph_def = meta_graph_def.object_graph_def
|
||||
path_to_metadata_pb = os.path.join(path, constants.SAVED_METADATA_PATH)
|
||||
if gfile.Exists(path_to_metadata_pb):
|
||||
try:
|
||||
with gfile.GFile(path_to_metadata_pb, 'rb') as f:
|
||||
file_content = f.read()
|
||||
metadata.ParseFromString(file_content)
|
||||
except message.DecodeError as e:
|
||||
raise IOError('Cannot parse keras metadata {}: {}.'
|
||||
.format(path_to_metadata_pb, str(e)))
|
||||
else:
|
||||
logging.warning('SavedModel saved prior to TF 2.4 detected when loading '
|
||||
'Keras model. Please ensure that you are saving the model '
|
||||
'with model.save() or tf.keras.models.save_model(), *NOT* '
|
||||
'tf.saved_model.save(). To confirm, there should be a file '
|
||||
'named "keras_metadata.pb" in the SavedModel directory.')
|
||||
_read_legacy_metadata(object_graph_def, metadata)
|
||||
# TODO(kathywu): When the keras metadata file is saved, load it directly
|
||||
# instead of calling the _read_legacy_metadata function.
|
||||
_read_legacy_metadata(object_graph_def, metadata)
|
||||
|
||||
if not metadata.nodes:
|
||||
# When there are no Keras objects, return the results from the core loader
|
||||
|
@ -18,21 +18,15 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from tensorflow.core.framework import versions_pb2
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.keras import backend as K
|
||||
from tensorflow.python.keras.protobuf import saved_metadata_pb2
|
||||
from tensorflow.python.keras.saving import saving_utils
|
||||
from tensorflow.python.keras.saving.saved_model import constants
|
||||
from tensorflow.python.keras.saving.saved_model import save_impl
|
||||
from tensorflow.python.keras.saving.saved_model import utils
|
||||
from tensorflow.python.keras.utils.generic_utils import LazyLoader
|
||||
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.saved_model import save as save_lib
|
||||
|
||||
|
||||
# To avoid circular dependencies between keras/engine and keras/saving,
|
||||
# code in keras/saving must delay imports.
|
||||
|
||||
@ -92,39 +86,7 @@ def save(model, filepath, overwrite, include_optimizer, signatures=None,
|
||||
# we use the default replica context here.
|
||||
with distribution_strategy_context._get_default_replica_context(): # pylint: disable=protected-access
|
||||
with utils.keras_option_scope(save_traces):
|
||||
saved_nodes, node_paths = save_lib.save_and_return_nodes(
|
||||
model, filepath, signatures, options)
|
||||
|
||||
# Save all metadata to a separate file in the SavedModel directory.
|
||||
metadata = generate_keras_metadata(saved_nodes, node_paths)
|
||||
|
||||
with gfile.GFile(
|
||||
os.path.join(filepath, constants.SAVED_METADATA_PATH), "wb") as w:
|
||||
w.write(metadata.SerializeToString(deterministic=True))
|
||||
save_lib.save(model, filepath, signatures, options)
|
||||
|
||||
if not include_optimizer:
|
||||
model.optimizer = orig_optimizer
|
||||
|
||||
|
||||
def generate_keras_metadata(saved_nodes, node_paths):
|
||||
"""Constructs a KerasMetadata proto with the metadata of each keras object."""
|
||||
metadata = saved_metadata_pb2.SavedMetadata()
|
||||
|
||||
for node_id, node in enumerate(saved_nodes):
|
||||
if isinstance(node, base_layer.Layer):
|
||||
path = node_paths[node]
|
||||
if not path:
|
||||
node_path = "root"
|
||||
else:
|
||||
node_path = "root.{}".format(
|
||||
".".join([ref.name for ref in path]))
|
||||
|
||||
metadata.nodes.add(
|
||||
node_id=node_id,
|
||||
node_path=node_path,
|
||||
version=versions_pb2.VersionDef(
|
||||
producer=1, min_consumer=1, bad_consumers=[]),
|
||||
identifier=node._object_identifier, # pylint: disable=protected-access
|
||||
metadata=node._tracking_metadata) # pylint: disable=protected-access
|
||||
|
||||
return metadata
|
||||
|
@ -26,6 +26,7 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
@ -410,16 +411,14 @@ class TestSavedModelFormatAllModes(keras_parameterized.TestCase):
|
||||
self.evaluate(variables.variables_initializer(model.variables))
|
||||
saved_model_dir = self._save_model_dir()
|
||||
|
||||
# TODO(kathywu): Re-enable this check after removing the tf.saved_model.save
|
||||
# metadata warning.
|
||||
# with self.captureWritesToStream(sys.stderr) as captured_logs:
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
with self.captureWritesToStream(sys.stderr) as captured_logs:
|
||||
model.save(saved_model_dir, save_format='tf')
|
||||
loaded = keras_load.load(saved_model_dir)
|
||||
|
||||
# Assert that saving does not log deprecation warnings
|
||||
# (even if it needs to set learning phase for compat reasons)
|
||||
# if context.executing_eagerly():
|
||||
# self.assertNotIn('deprecated', captured_logs.contents())
|
||||
if context.executing_eagerly():
|
||||
self.assertNotIn('deprecated', captured_logs.contents())
|
||||
|
||||
input_arr = array_ops.constant([[11], [12], [13]], dtype=dtypes.float32)
|
||||
input_arr2 = array_ops.constant([[14], [15], [16]], dtype=dtypes.float32)
|
||||
|
@ -349,7 +349,6 @@ py_strict_library(
|
||||
"//tensorflow/python:framework",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:lib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python:tf_export",
|
||||
|
@ -43,7 +43,6 @@ from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.saved_model import builder_impl
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import function_serialization
|
||||
@ -183,9 +182,8 @@ class _SaveableView(object):
|
||||
"""
|
||||
self.options = options
|
||||
self.checkpoint_view = checkpoint_view
|
||||
trackable_objects, path_to_root, node_ids, slot_variables = (
|
||||
self.checkpoint_view.objects_ids_and_slot_variables_and_paths())
|
||||
self.node_paths = path_to_root
|
||||
trackable_objects, node_ids, slot_variables = (
|
||||
self.checkpoint_view.objects_ids_and_slot_variables())
|
||||
self.nodes = trackable_objects
|
||||
self.node_ids = node_ids
|
||||
self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
|
||||
@ -752,17 +750,14 @@ def _serialize_object_graph(saveable_view, asset_file_def_index):
|
||||
if serialized is not None:
|
||||
proto.concrete_functions[name].CopyFrom(serialized)
|
||||
|
||||
saved_object_metadata = False
|
||||
for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
|
||||
has_saved_object_metadata = _write_object_proto(
|
||||
obj, obj_proto, asset_file_def_index, saveable_view.function_name_map)
|
||||
saved_object_metadata = saved_object_metadata or has_saved_object_metadata
|
||||
return proto, saved_object_metadata
|
||||
_write_object_proto(obj, obj_proto, asset_file_def_index,
|
||||
saveable_view.function_name_map)
|
||||
return proto
|
||||
|
||||
|
||||
def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
|
||||
"""Saves an object into SavedObject proto."""
|
||||
has_saved_object_metadata = False # The metadata field will be deprecated.
|
||||
if isinstance(obj, tracking.Asset):
|
||||
proto.asset.SetInParent()
|
||||
proto.asset.asset_file_def_index = asset_file_def_index[obj]
|
||||
@ -798,14 +793,11 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
|
||||
if registered_type_proto is None:
|
||||
# Fallback for types with no matching registration
|
||||
# pylint:disable=protected-access
|
||||
metadata = obj._tracking_metadata
|
||||
if metadata:
|
||||
has_saved_object_metadata = True
|
||||
registered_type_proto = saved_object_graph_pb2.SavedUserObject(
|
||||
identifier=obj._object_identifier,
|
||||
version=versions_pb2.VersionDef(
|
||||
producer=1, min_consumer=1, bad_consumers=[]),
|
||||
metadata=metadata)
|
||||
metadata=obj._tracking_metadata)
|
||||
# pylint:enable=protected-access
|
||||
proto.user_object.CopyFrom(registered_type_proto)
|
||||
|
||||
@ -818,7 +810,6 @@ def _write_object_proto(obj, proto, asset_file_def_index, function_name_map):
|
||||
# documentation.
|
||||
if hasattr(obj, "_write_object_proto"):
|
||||
obj._write_object_proto(proto, options) # pylint: disable=protected-access
|
||||
return has_saved_object_metadata
|
||||
|
||||
|
||||
def _export_debug_info(exported_graph, export_dir):
|
||||
@ -1016,7 +1007,8 @@ def save(obj, export_dir, signatures=None, options=None):
|
||||
instances with input signatures or concrete functions. Keys of such a
|
||||
dictionary may be arbitrary strings, but will typically be from the
|
||||
`tf.saved_model.signature_constants` module.
|
||||
options: `tf.saved_model.SaveOptions` object for configuring save options.
|
||||
options: Optional, `tf.saved_model.SaveOptions` object that specifies
|
||||
options for saving.
|
||||
|
||||
Raises:
|
||||
ValueError: If `obj` is not trackable.
|
||||
@ -1030,30 +1022,6 @@ def save(obj, export_dir, signatures=None, options=None):
|
||||
May not be called from within a function body.
|
||||
@end_compatibility
|
||||
"""
|
||||
save_and_return_nodes(obj, export_dir, signatures, options,
|
||||
raise_metadata_warning=True)
|
||||
|
||||
|
||||
def save_and_return_nodes(obj, export_dir, signatures=None, options=None,
|
||||
raise_metadata_warning=False):
|
||||
"""Saves a SavedModel while returning all saved nodes and their paths.
|
||||
|
||||
Please see `tf.saved_model.save` for details.
|
||||
|
||||
Args:
|
||||
obj: A trackable object to export.
|
||||
export_dir: A directory in which to write the SavedModel.
|
||||
signatures: A function or dictionary of functions to save in the SavedModel
|
||||
as signatures.
|
||||
options: `tf.saved_model.SaveOptions` object for configuring save options.
|
||||
raise_metadata_warning: Whether to raise the metadata warning. This arg will
|
||||
be removed in TF 2.5.
|
||||
|
||||
Returns:
|
||||
A tuple of (a list of saved nodes in the order they are serialized to the
|
||||
`SavedObjectGraph`, dictionary mapping nodes to one possible path from
|
||||
the root node to the key node)
|
||||
"""
|
||||
options = options or save_options.SaveOptions()
|
||||
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
|
||||
# compatible (no sessions) and share it with this export API rather than
|
||||
@ -1061,9 +1029,8 @@ def save_and_return_nodes(obj, export_dir, signatures=None, options=None,
|
||||
saved_model = saved_model_pb2.SavedModel()
|
||||
meta_graph_def = saved_model.meta_graphs.add()
|
||||
|
||||
_, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
|
||||
_build_meta_graph(obj, signatures, options, meta_graph_def,
|
||||
raise_metadata_warning))
|
||||
_, exported_graph, object_saver, asset_info = _build_meta_graph(
|
||||
obj, signatures, options, meta_graph_def)
|
||||
saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
|
||||
|
||||
# Write the checkpoint, copy assets into the assets directory, and write out
|
||||
@ -1103,8 +1070,6 @@ def save_and_return_nodes(obj, export_dir, signatures=None, options=None,
|
||||
# constants in the saved graph.
|
||||
ops.dismantle_graph(exported_graph)
|
||||
|
||||
return saved_nodes, node_paths
|
||||
|
||||
|
||||
def export_meta_graph(obj, filename, signatures=None, options=None):
|
||||
"""Exports the MetaGraph proto of the `obj` to a file.
|
||||
@ -1131,7 +1096,7 @@ def export_meta_graph(obj, filename, signatures=None, options=None):
|
||||
"""
|
||||
options = options or save_options.SaveOptions()
|
||||
export_dir = os.path.dirname(filename)
|
||||
meta_graph_def, exported_graph, _, _, _, _ = _build_meta_graph(
|
||||
meta_graph_def, exported_graph, _, _ = _build_meta_graph(
|
||||
obj, signatures, options)
|
||||
|
||||
file_io.atomic_write_string_to_file(
|
||||
@ -1150,8 +1115,7 @@ def export_meta_graph(obj, filename, signatures=None, options=None):
|
||||
def _build_meta_graph_impl(obj,
|
||||
signatures,
|
||||
options,
|
||||
meta_graph_def=None,
|
||||
raise_metadata_warning=True):
|
||||
meta_graph_def=None):
|
||||
"""Creates a MetaGraph containing the resources and functions of an object."""
|
||||
if ops.inside_function():
|
||||
raise AssertionError(
|
||||
@ -1195,35 +1159,17 @@ def _build_meta_graph_impl(obj,
|
||||
for fdef in func._stateless_fn._function_cache.all_values(): # pylint: disable=protected-access
|
||||
function_aliases[fdef.name] = alias
|
||||
|
||||
object_graph_proto, saved_object_metadata = _serialize_object_graph(
|
||||
saveable_view, asset_info.asset_index)
|
||||
object_graph_proto = _serialize_object_graph(saveable_view,
|
||||
asset_info.asset_index)
|
||||
meta_graph_def.object_graph_def.CopyFrom(object_graph_proto)
|
||||
|
||||
if saved_object_metadata and raise_metadata_warning:
|
||||
tf_logging.warn(
|
||||
'FOR KERAS USERS: The object that you are saving contains one or more '
|
||||
'Keras models or layers. If you are loading the SavedModel with '
|
||||
'`tf.keras.models.load_model`, continue reading (otherwise, you may '
|
||||
'ignore the following instructions). Please change your code to save '
|
||||
'with `tf.keras.models.save_model` or `model.save`, and confirm that '
|
||||
'the file "keras.metadata" exists in the export directory. In the '
|
||||
'future, Keras will only load the SavedModels that have this file. In '
|
||||
'other words, `tf.saved_model.save` will no longer write SavedModels '
|
||||
'that can be recovered as Keras models (this will apply in TF 2.5).'
|
||||
'\n\nFOR DEVS: If you are overwriting _tracking_metadata in your class,'
|
||||
' this property has been used to save metadata in the SavedModel. The '
|
||||
'metadta field will be deprecated soon, so please move the metadata to '
|
||||
'a different file.')
|
||||
|
||||
return (meta_graph_def, exported_graph, object_saver, asset_info,
|
||||
saveable_view.nodes, saveable_view.node_paths)
|
||||
return meta_graph_def, exported_graph, object_saver, asset_info
|
||||
|
||||
|
||||
def _build_meta_graph(obj,
|
||||
signatures,
|
||||
options,
|
||||
meta_graph_def=None,
|
||||
raise_metadata_warning=True):
|
||||
meta_graph_def=None):
|
||||
"""Creates a MetaGraph under a save context.
|
||||
|
||||
Args:
|
||||
@ -1236,8 +1182,6 @@ def _build_meta_graph(obj,
|
||||
options: `tf.saved_model.SaveOptions` object that specifies options for
|
||||
saving.
|
||||
meta_graph_def: Optional, the MetaGraphDef proto fill.
|
||||
raise_metadata_warning: Whether to raise a warning when user objects contain
|
||||
non-empty metadata.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `export_meta_graph` is executing inside a `tf.function`.
|
||||
@ -1251,5 +1195,4 @@ def _build_meta_graph(obj,
|
||||
"""
|
||||
|
||||
with save_context.save_context(options):
|
||||
return _build_meta_graph_impl(obj, signatures, options, meta_graph_def,
|
||||
raise_metadata_warning)
|
||||
return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
|
||||
|
@ -430,7 +430,7 @@ class ObjectGraphView(object):
|
||||
name=base.OBJECT_GRAPH_PROTO_KEY))
|
||||
return named_saveable_objects
|
||||
|
||||
def objects_ids_and_slot_variables_and_paths(self):
|
||||
def objects_ids_and_slot_variables(self):
|
||||
"""Traverse the object graph and list all accessible objects.
|
||||
|
||||
Looks for `Trackable` objects which are dependencies of
|
||||
@ -439,8 +439,7 @@ class ObjectGraphView(object):
|
||||
(i.e. if they would be saved with a checkpoint).
|
||||
|
||||
Returns:
|
||||
A tuple of (trackable objects, paths from root for each object,
|
||||
object -> node id, slot variables)
|
||||
A tuple of (trackable objects, object -> node id, slot variables)
|
||||
"""
|
||||
trackable_objects, path_to_root = self._breadth_first_traversal()
|
||||
object_names = object_identity.ObjectIdentityDictionary()
|
||||
@ -453,11 +452,6 @@ class ObjectGraphView(object):
|
||||
trackable_objects=trackable_objects,
|
||||
node_ids=node_ids,
|
||||
object_names=object_names)
|
||||
return trackable_objects, path_to_root, node_ids, slot_variables
|
||||
|
||||
def objects_ids_and_slot_variables(self):
|
||||
trackable_objects, _, node_ids, slot_variables = (
|
||||
self.objects_ids_and_slot_variables_and_paths())
|
||||
return trackable_objects, node_ids, slot_variables
|
||||
|
||||
def list_objects(self):
|
||||
|
Loading…
Reference in New Issue
Block a user