Save main_op and train_op in SignatureDefs instead of collections.

PiperOrigin-RevId: 223268983
This commit is contained in:
Katherine Wu 2018-11-28 17:55:17 -08:00 committed by TensorFlower Gardener
parent 99c20bf32e
commit 693c1ad4a9
21 changed files with 583 additions and 365 deletions

View File

@ -133,5 +133,6 @@ filegroup(
"testdata/half_plus_two_pbtxt/**", "testdata/half_plus_two_pbtxt/**",
"testdata/half_plus_two_main_op/**", "testdata/half_plus_two_main_op/**",
"testdata/half_plus_two/**", "testdata/half_plus_two/**",
"testdata/half_plus_two_v2/**",
]), ]),
) )

View File

@ -33,10 +33,10 @@ constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
/// SavedModel text format proto filename. /// SavedModel text format proto filename.
constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt"; constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
/// SavedModel legacy init op key. /// SavedModel legacy init op collection key. Used in v1 SavedModels.
constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op"; constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
/// SavedModel main op key. /// SavedModel main op collection key. Used in v1 SavedModels.
constexpr char kSavedModelMainOpKey[] = "saved_model_main_op"; constexpr char kSavedModelMainOpKey[] = "saved_model_main_op";
/// Directory in which to save the SavedModel variables. /// Directory in which to save the SavedModel variables.
@ -45,6 +45,11 @@ constexpr char kSavedModelVariablesDirectory[] = "variables";
/// SavedModel variables filename. /// SavedModel variables filename.
constexpr char kSavedModelVariablesFilename[] = "variables"; constexpr char kSavedModelVariablesFilename[] = "variables";
/// SavedModel SignatureDef keys for the initialization and train ops. Used in
/// V2 SavedModels.
constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op";
constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op";
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ #endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_

View File

@ -122,38 +122,58 @@ Status RunOnce(const RunOptions& run_options,
return run_status; return run_status;
} }
bool HasMainOp(const MetaGraphDef& meta_graph_def) { // RunInitOp will return OK if the initialization op was run successfully.
const auto& collection_def_map = meta_graph_def.collection_def(); // An empty init_op_name indicates that there are no init ops to run.
if (collection_def_map.find(kSavedModelMainOpKey) != Status RunInitOp(const RunOptions& run_options, const string& export_dir,
collection_def_map.end()) {
return true;
}
return false;
}
Status RunMainOp(const RunOptions& run_options, const string& export_dir,
const MetaGraphDef& meta_graph_def, const MetaGraphDef& meta_graph_def,
const std::vector<AssetFileDef>& asset_file_defs, const std::vector<AssetFileDef>& asset_file_defs,
Session* session, const string& main_op_key) { Session* session, const string& init_op_name) {
LOG(INFO) << "Running MainOp with key " << main_op_key if (!init_op_name.empty()) {
<< " on SavedModel bundle."; LOG(INFO) << "Running initialization op on SavedModel bundle.";
const auto& collection_def_map = meta_graph_def.collection_def();
const auto main_op_it = collection_def_map.find(main_op_key);
if (main_op_it != collection_def_map.end()) {
if (main_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
std::vector<std::pair<string, Tensor>> inputs; std::vector<std::pair<string, Tensor>> inputs;
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs); AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
RunMetadata run_metadata; RunMetadata run_metadata;
const StringPiece main_op_name = main_op_it->second.node_list().value(0); return RunOnce(run_options, inputs, {}, {init_op_name},
return RunOnce(run_options, inputs, {}, {string(main_op_name)},
nullptr /* outputs */, &run_metadata, session); nullptr /* outputs */, &run_metadata, session);
} }
return Status::OK(); return Status::OK();
} }
// A SavedModel may store the name of the initialization op to run in the
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name) {
const auto& sig_def_map = meta_graph_def.signature_def();
const auto& init_op_sig_it =
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
if (init_op_sig_it != sig_def_map.end()) {
*init_op_name = init_op_sig_it->second.outputs()
.find(kSavedModelInitOpSignatureKey)
->second.name();
return Status::OK();
}
const auto& collection_def_map = meta_graph_def.collection_def();
string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
init_op_collection_key = kSavedModelMainOpKey;
} else {
init_op_collection_key = kSavedModelLegacyInitOpKey;
}
const auto init_op_it = collection_def_map.find(init_op_collection_key);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
*init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
Status RunRestore(const RunOptions& run_options, const string& export_dir, Status RunRestore(const RunOptions& run_options, const string& export_dir,
const StringPiece restore_op_name, const StringPiece restore_op_name,
const StringPiece variable_filename_const_op_name, const StringPiece variable_filename_const_op_name,
@ -236,15 +256,12 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
bundle->meta_graph_def.saver_def().restore_op_name(), bundle->meta_graph_def.saver_def().restore_op_name(),
bundle->meta_graph_def.saver_def().filename_tensor_name(), bundle->meta_graph_def.saver_def().filename_tensor_name(),
asset_file_defs, bundle->session.get())); asset_file_defs, bundle->session.get()));
if (HasMainOp(bundle->meta_graph_def)) { string init_op_name;
TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir, TF_RETURN_IF_ERROR(
bundle->meta_graph_def, asset_file_defs, GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
bundle->session.get(), kSavedModelMainOpKey)); TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
} else { asset_file_defs, bundle->session.get(),
TF_RETURN_IF_ERROR(RunMainOp( init_op_name));
run_options, export_dir, bundle->meta_graph_def, asset_file_defs,
bundle->session.get(), kSavedModelLegacyInitOpKey));
}
return Status::OK(); return Status::OK();
} }

View File

@ -36,6 +36,8 @@ constexpr char kTestDataMainOp[] =
"cc/saved_model/testdata/half_plus_two_main_op/00000123"; "cc/saved_model/testdata/half_plus_two_main_op/00000123";
constexpr char kTestDataSharded[] = constexpr char kTestDataSharded[] =
"cc/saved_model/testdata/half_plus_two/00000123"; "cc/saved_model/testdata/half_plus_two/00000123";
constexpr char kTestDataInitOpV2[] =
"cc/saved_model/testdata/half_plus_two_v2/00000123";
class LoaderTest : public ::testing::Test { class LoaderTest : public ::testing::Test {
protected: protected:
@ -227,5 +229,17 @@ TEST_F(LoaderTest, MaybeSavedModelDirectory) {
EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir)); EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir));
} }
TEST_F(LoaderTest, SavedModelInitOpV2Format) {
SavedModelBundle bundle;
SessionOptions session_options;
RunOptions run_options;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataInitOpV2);
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
{kSavedModelTagServe}, &bundle));
CheckSavedModelBundle(export_dir, bundle);
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -0,0 +1 @@
asset-file-contents

View File

@ -125,7 +125,7 @@ def save_keras_model(
export_dir = export_helpers.get_timestamped_export_dir(saved_model_path) export_dir = export_helpers.get_timestamped_export_dir(saved_model_path)
temp_export_dir = export_helpers.get_temp_export_dir(export_dir) temp_export_dir = export_helpers.get_temp_export_dir(export_dir)
builder = saved_model_builder.SavedModelBuilder(temp_export_dir) builder = saved_model_builder._SavedModelBuilder(temp_export_dir)
# Manually save variables to export them in an object-based checkpoint. This # Manually save variables to export them in an object-based checkpoint. This
# skips the `builder.add_meta_graph_and_variables()` step, which saves a # skips the `builder.add_meta_graph_and_variables()` step, which saves a
@ -227,9 +227,10 @@ def _export_mode(
g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations) g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)
# Extract update and train ops from train/test/predict functions. # Extract update and train ops from train/test/predict functions.
train_op = None
if mode == model_fn_lib.ModeKeys.TRAIN: if mode == model_fn_lib.ModeKeys.TRAIN:
clone._make_train_function() clone._make_train_function()
builder._add_train_op(clone.train_function.updates_op) train_op = clone.train_function.updates_op
elif mode == model_fn_lib.ModeKeys.EVAL: elif mode == model_fn_lib.ModeKeys.EVAL:
clone._make_test_function() clone._make_test_function()
else: else:
@ -264,7 +265,8 @@ def _export_mode(
model_fn_lib.EXPORT_TAG_MAP[mode], model_fn_lib.EXPORT_TAG_MAP[mode],
signature_def_map=_create_signature_def_map(clone, mode), signature_def_map=_create_signature_def_map(clone, mode),
saver=saver_lib.Saver(clone_var_list), saver=saver_lib.Saver(clone_var_list),
main_op=variables.local_variables_initializer()) init_op=variables.local_variables_initializer(),
train_op=train_op)
return None return None

View File

@ -35,7 +35,6 @@ from tensorflow.python.keras.engine import training
from tensorflow.python.keras.utils import tf_utils from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import training as training_module from tensorflow.python.training import training as training_module
@ -254,7 +253,7 @@ def load_model(sess, path, mode):
outputs = { outputs = {
k: sess.graph.get_tensor_by_name(v.name) k: sess.graph.get_tensor_by_name(v.name)
for k, v in meta_graph_def.signature_def[sig_def_key].outputs.items()} for k, v in meta_graph_def.signature_def[sig_def_key].outputs.items()}
return inputs, outputs return inputs, outputs, meta_graph_def
@test_util.run_all_in_graph_and_eager_modes @test_util.run_all_in_graph_and_eager_modes
@ -331,7 +330,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
# Load predict graph, and test predictions # Load predict graph, and test predictions
with session.Session(graph=ops.Graph()) as sess: with session.Session(graph=ops.Graph()) as sess:
inputs, outputs = load_model(sess, output_path, inputs, outputs, _ = load_model(sess, output_path,
model_fn_lib.ModeKeys.PREDICT) model_fn_lib.ModeKeys.PREDICT)
predictions = sess.run(outputs[output_name], predictions = sess.run(outputs[output_name],
@ -341,7 +340,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
if optimizer: if optimizer:
# Load eval graph, and test predictions, loss and metric values # Load eval graph, and test predictions, loss and metric values
with session.Session(graph=ops.Graph()) as sess: with session.Session(graph=ops.Graph()) as sess:
inputs, outputs = load_model(sess, output_path, inputs, outputs, _ = load_model(sess, output_path,
model_fn_lib.ModeKeys.EVAL) model_fn_lib.ModeKeys.EVAL)
# First obtain the loss and predictions, and run the metric update op by # First obtain the loss and predictions, and run the metric update op by
@ -365,8 +364,8 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
# Load train graph, and check for the train op, and prediction values # Load train graph, and check for the train op, and prediction values
with session.Session(graph=ops.Graph()) as sess: with session.Session(graph=ops.Graph()) as sess:
inputs, outputs = load_model(sess, output_path, inputs, outputs, meta_graph_def = load_model(
model_fn_lib.ModeKeys.TRAIN) sess, output_path, model_fn_lib.ModeKeys.TRAIN)
self.assertEqual(int(train_before_export), self.assertEqual(int(train_before_export),
sess.run(training_module.get_global_step())) sess.run(training_module.get_global_step()))
self.assertIn('loss', outputs) self.assertIn('loss', outputs)
@ -375,7 +374,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
self.assertIn('predictions/' + output_name, outputs) self.assertIn('predictions/' + output_name, outputs)
# Train for a step # Train for a step
train_op = ops.get_collection(constants.TRAIN_OP_KEY) train_op = loader_impl.get_train_op(meta_graph_def)
train_outputs, _ = sess.run( train_outputs, _ = sess.run(
[outputs, train_op], {inputs[input_name]: input_arr, [outputs, train_op], {inputs[input_name]: input_arr,
inputs[target_name]: target_arr}) inputs[target_name]: target_arr})
@ -402,7 +401,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
output_path = keras_saved_model.save_keras_model( output_path = keras_saved_model.save_keras_model(
model, saved_model_path, custom_objects={'relu6': relu6}) model, saved_model_path, custom_objects={'relu6': relu6})
with session.Session(graph=ops.Graph()) as sess: with session.Session(graph=ops.Graph()) as sess:
inputs, outputs = load_model(sess, output_path, inputs, outputs, _ = load_model(sess, output_path,
model_fn_lib.ModeKeys.PREDICT) model_fn_lib.ModeKeys.PREDICT)
input_name = model.input_names[0] input_name = model.input_names[0]
output_name = model.output_names[0] output_name = model.output_names[0]

View File

@ -83,6 +83,7 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
deps = [ deps = [
":constants", ":constants",
":signature_def_utils",
":utils", ":utils",
"//tensorflow/core:protos_all_py", "//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_for_generated_wrappers",
@ -114,6 +115,7 @@ py_test(
"//tensorflow/python:state_ops", "//tensorflow/python:state_ops",
"//tensorflow/python:training", "//tensorflow/python:training",
"//tensorflow/python:variables", "//tensorflow/python:variables",
"@absl_py//absl/testing:parameterized",
], ],
) )

View File

@ -33,6 +33,7 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat from tensorflow.python.util import compat
@ -40,6 +41,8 @@ from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
# Base class for the SavedModelBuilder that is only used by Tensorflow
# internally. Please use tf.compat.v1.saved_model.SavedModelBuilder instead.
class _SavedModelBuilder(object): class _SavedModelBuilder(object):
"""Builds the `SavedModel` protocol buffer and saves variables and assets. """Builds the `SavedModel` protocol buffer and saves variables and assets.
@ -144,52 +147,6 @@ class _SavedModelBuilder(object):
# Copy assets from source path to destination path. # Copy assets from source path to destination path.
self._copy_assets_to_destination_dir(asset_filename_map) self._copy_assets_to_destination_dir(asset_filename_map)
def _maybe_add_main_op(self, main_op):
"""Adds main op to the SavedModel.
Args:
main_op: Main op to run as part of graph initialization. If None, no
main op will be added to the graph.
Raises:
TypeError: if main op is provided but is not of type `Operation`.
ValueError: if the Graph already contains an init op.
"""
if main_op is None:
return
if not isinstance(main_op, ops.Operation):
raise TypeError("main_op needs to be an Operation: %r" % main_op)
# Validate that no other init ops have been added to this graph already.
# We check main_op and legacy_init_op for thoroughness and explicitness.
for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
if ops.get_collection(init_op_key):
raise ValueError(
"Graph already contains one or more main ops under the "
"collection {}.".format(init_op_key))
ops.add_to_collection(constants.MAIN_OP_KEY, main_op)
def _add_train_op(self, train_op):
"""Add train op to the SavedModel.
Note that this functionality is in development, and liable to be
moved elsewhere.
Args:
train_op: Op or group of ops that are used for training. These are
stored as a collection with key TRAIN_OP_KEY, but not executed.
Raises:
TypeError if Train op is not of type `Operation`.
"""
if train_op is not None:
if (not isinstance(train_op, ops.Tensor) and
not isinstance(train_op, ops.Operation)):
raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)
def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map): def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
"""Tags the meta graph def and adds it to the SavedModel. """Tags the meta graph def and adds it to the SavedModel.
@ -245,12 +202,16 @@ class _SavedModelBuilder(object):
Validation of entries in the signature def map includes ensuring that the Validation of entries in the signature def map includes ensuring that the
`name` and `dtype` fields of the TensorInfo protos of the `inputs` and `name` and `dtype` fields of the TensorInfo protos of the `inputs` and
`outputs` of each `SignatureDef` are populated. `outputs` of each `SignatureDef` are populated. Also ensures that reserved
SigantureDef keys for the initialization and train ops are not used.
Args: Args:
signature_def_map: The map of signature defs to be validated. signature_def_map: The map of signature defs to be validated.
Raises:
AssertionError: If a TensorInfo is not valid.
KeyError: If a reserved signature key is used in the map.
""" """
if signature_def_map is not None:
for signature_def_key in signature_def_map: for signature_def_key in signature_def_map:
signature_def = signature_def_map[signature_def_key] signature_def = signature_def_map[signature_def_key]
inputs = signature_def.inputs inputs = signature_def.inputs
@ -259,12 +220,14 @@ class _SavedModelBuilder(object):
self._validate_tensor_info(inputs[inputs_key]) self._validate_tensor_info(inputs[inputs_key])
for outputs_key in outputs: for outputs_key in outputs:
self._validate_tensor_info(outputs[outputs_key]) self._validate_tensor_info(outputs[outputs_key])
if constants.INIT_OP_SIGNATURE_KEY in signature_def_map:
def _add_collections(self, main_op, train_op): raise KeyError(
"""Add asset and op collections to be saved.""" "SignatureDef map key \"{}\" is reserved for initialization. Please "
self._maybe_add_main_op(main_op) "use a different key.".format(constants.INIT_OP_SIGNATURE_KEY))
if constants.TRAIN_OP_SIGNATURE_KEY in signature_def_map:
self._add_train_op(train_op) raise KeyError(
"SignatureDef map key \"{}\" is reserved for the train op. Please "
"use a different key.".format(constants.TRAIN_OP_SIGNATURE_KEY))
def _maybe_create_saver(self, saver=None): def _maybe_create_saver(self, saver=None):
"""Creates a sharded saver if one does not already exist.""" """Creates a sharded saver if one does not already exist."""
@ -278,19 +241,14 @@ class _SavedModelBuilder(object):
allow_empty=True) allow_empty=True)
return saver return saver
@deprecated_args(None,
"Pass your op to the equivalent parameter main_op instead.",
"legacy_init_op")
def add_meta_graph(self, def add_meta_graph(self,
tags, tags,
signature_def_map=None, signature_def_map=None,
assets_list=None, assets_list=None,
legacy_init_op=None,
clear_devices=False, clear_devices=False,
main_op=None, init_op=None,
strip_default_attrs=False, train_op=None,
saver=None): saver=None):
# pylint: disable=line-too-long
"""Adds the current meta graph to the SavedModel. """Adds the current meta graph to the SavedModel.
Creates a Saver in the current scope and uses the Saver to export the meta Creates a Saver in the current scope and uses the Saver to export the meta
@ -304,16 +262,14 @@ class _SavedModelBuilder(object):
assets_list: Assets to be saved with SavedModel. Note assets_list: Assets to be saved with SavedModel. Note
that this list should be a subset of the assets saved as part of that this list should be a subset of the assets saved as part of
the first meta graph in the SavedModel. the first meta graph in the SavedModel.
legacy_init_op: Legacy support for op or group of ops to execute after the
restore op upon a load. Deprecated; please use main_op instead.
clear_devices: Set to true if the device info on the default graph should clear_devices: Set to true if the device info on the default graph should
be cleared. be cleared.
main_op: Op or group of ops to execute when the graph is loaded. Note init_op: Op or group of ops to execute when the graph is loaded. Note
that when the main_op is specified it is run after the restore op at that when the init_op is specified it is run after the restore op at
load-time. load-time.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be train_op: Op or group of opts that trains the model when run. This will
removed from the NodeDefs. For a detailed guide, see not be run automatically when the graph is loaded, instead saved in
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). a SignatureDef accessible through the exported MetaGraph.
saver: An instance of tf.train.Saver that will be used to export the saver: An instance of tf.train.Saver that will be used to export the
metagraph. If None, a sharded Saver that restores all variables will metagraph. If None, a sharded Saver that restores all variables will
be used. be used.
@ -322,7 +278,6 @@ class _SavedModelBuilder(object):
AssertionError: If the variables for the SavedModel have not been saved AssertionError: If the variables for the SavedModel have not been saved
yet, or if the graph already contains one or more legacy init ops. yet, or if the graph already contains one or more legacy init ops.
""" """
# pylint: enable=line-too-long
if not self._has_saved_variables: if not self._has_saved_variables:
raise AssertionError( raise AssertionError(
"Graph state including variables and assets has not been saved yet. " "Graph state including variables and assets has not been saved yet. "
@ -330,14 +285,15 @@ class _SavedModelBuilder(object):
# Validate the signature def map to ensure all included TensorInfos are # Validate the signature def map to ensure all included TensorInfos are
# properly populated. # properly populated.
signature_def_map = signature_def_map or {}
self._validate_signature_def_map(signature_def_map) self._validate_signature_def_map(signature_def_map)
# legacy_init_op is deprecated, and going away in TF 2.0. # Create a SignatureDef pointing to the graph initialization op, which will
# Re-mapping to main_op, as treatment is identical regardless. # be added to the MetaGraphDef.
main_op = main_op or legacy_init_op _add_op_to_signature_def_map(signature_def_map, init_op,
constants.INIT_OP_SIGNATURE_KEY)
# Add ops to collection. _add_op_to_signature_def_map(signature_def_map, train_op,
self._add_collections(main_op=main_op, train_op=None) constants.TRAIN_OP_SIGNATURE_KEY)
saver = self._maybe_create_saver(saver) saver = self._maybe_create_saver(saver)
@ -349,7 +305,7 @@ class _SavedModelBuilder(object):
# resolved, we just leave the option set to False for now. # resolved, we just leave the option set to False for now.
# TODO(soergel): Reinstate clear_extraneous_savers=True when possible. # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
meta_graph_def = saver.export_meta_graph( meta_graph_def = saver.export_meta_graph(
clear_devices=clear_devices, strip_default_attrs=strip_default_attrs) clear_devices=clear_devices, strip_default_attrs=True)
# Save asset files and write them to disk, if any. # Save asset files and write them to disk, if any.
self._save_and_write_assets(meta_graph_def, assets_list) self._save_and_write_assets(meta_graph_def, assets_list)
@ -357,17 +313,14 @@ class _SavedModelBuilder(object):
# Tag the meta graph def and add it to the SavedModel. # Tag the meta graph def and add it to the SavedModel.
self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map) self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
@deprecated_args(None,
"Pass your op to the equivalent parameter main_op instead.",
"legacy_init_op")
def add_meta_graph_and_variables(self, def add_meta_graph_and_variables(self,
sess, sess,
tags, tags,
signature_def_map=None, signature_def_map=None,
assets_list=None, assets_list=None,
legacy_init_op=None,
clear_devices=False, clear_devices=False,
main_op=None, init_op=None,
train_op=None,
strip_default_attrs=False, strip_default_attrs=False,
saver=None): saver=None):
# pylint: disable=line-too-long # pylint: disable=line-too-long
@ -386,13 +339,14 @@ class _SavedModelBuilder(object):
signature_def_map: The map of signature def map to add to the meta graph signature_def_map: The map of signature def map to add to the meta graph
def. def.
assets_list: Assets to be saved with SavedModel. assets_list: Assets to be saved with SavedModel.
legacy_init_op: Legacy support for op or group of ops to execute after the
restore op upon a load. Deprecated; please use main_op instead.
clear_devices: Set to true if the device info on the default graph should clear_devices: Set to true if the device info on the default graph should
be cleared. be cleared.
main_op: Op or group of ops to execute when the graph is loaded. Note init_op: Op or group of ops to execute when the graph is loaded. Note
that when the main_op is specified it is run after the restore op at that when the init_op is specified it is run after the restore op at
load-time. load-time.
train_op: Op or group of ops that trains the model when run. This will
not be run automatically when the graph is loaded, instead saved in
a SignatureDef accessible through the exported MetaGraph.
strip_default_attrs: Boolean. If `True`, default-valued attributes will be strip_default_attrs: Boolean. If `True`, default-valued attributes will be
removed from the NodeDefs. For a detailed guide, see removed from the NodeDefs. For a detailed guide, see
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
@ -409,14 +363,15 @@ class _SavedModelBuilder(object):
# Validate the signature def map to ensure all included TensorInfos are # Validate the signature def map to ensure all included TensorInfos are
# properly populated. # properly populated.
signature_def_map = signature_def_map or {}
self._validate_signature_def_map(signature_def_map) self._validate_signature_def_map(signature_def_map)
# legacy_init_op is deprecated, and going away in TF 2.0. # Create a SignatureDef pointing to the graph initialization op, which will
# Re-mapping to main_op, as treatment is identical regardless. # be added to the MetaGraphDef.
main_op = main_op or legacy_init_op _add_op_to_signature_def_map(signature_def_map, init_op,
constants.INIT_OP_SIGNATURE_KEY)
# Add ops to collection. _add_op_to_signature_def_map(signature_def_map, train_op,
self._add_collections(main_op=main_op, train_op=None) constants.TRAIN_OP_SIGNATURE_KEY)
saved_model_utils.get_or_create_variables_dir(self._export_dir) saved_model_utils.get_or_create_variables_dir(self._export_dir)
variables_path = saved_model_utils.get_variables_path(self._export_dir) variables_path = saved_model_utils.get_variables_path(self._export_dir)
@ -517,6 +472,52 @@ class SavedModelBuilder(_SavedModelBuilder):
# Copy assets from source path to destination path. # Copy assets from source path to destination path.
self._copy_assets_to_destination_dir(asset_filename_map) self._copy_assets_to_destination_dir(asset_filename_map)
def _maybe_add_main_op(self, main_op):
"""Adds main op to the SavedModel.
Args:
main_op: Main op to run as part of graph initialization. If None, no main
op will be added to the graph.
Raises:
TypeError: if main op is provided but is not of type `Operation`.
ValueError: if the Graph already contains an init op.
"""
if main_op is None:
return
if not isinstance(main_op, ops.Operation):
raise TypeError("main_op needs to be an Operation: %r" % main_op)
# Validate that no other init ops have been added to this graph already.
# We check main_op and legacy_init_op for thoroughness and explicitness.
for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
if ops.get_collection(init_op_key):
raise ValueError(
"Graph already contains one or more main ops under the "
"collection {}.".format(init_op_key))
ops.add_to_collection(constants.MAIN_OP_KEY, main_op)
def _add_train_op(self, train_op):
"""Add train op to the SavedModel.
Note that this functionality is in development, and liable to be
moved elsewhere.
Args:
train_op: Op or group of ops that are used for training. These are stored
as a collection with key TRAIN_OP_KEY, but not executed.
Raises:
TypeError if Train op is not of type `Operation`.
"""
if train_op is not None:
if (not isinstance(train_op, ops.Tensor) and
not isinstance(train_op, ops.Operation)):
raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)
@deprecated_args(None, @deprecated_args(None,
"Pass your op to the equivalent parameter main_op instead.", "Pass your op to the equivalent parameter main_op instead.",
"legacy_init_op") "legacy_init_op")
@ -536,6 +537,7 @@ class SavedModelBuilder(_SavedModelBuilder):
# Validate the signature def map to ensure all included TensorInfos are # Validate the signature def map to ensure all included TensorInfos are
# properly populated. # properly populated.
signature_def_map = signature_def_map or {}
self._validate_signature_def_map(signature_def_map) self._validate_signature_def_map(signature_def_map)
# legacy_init_op is deprecated, and going away in TF 2.0. # legacy_init_op is deprecated, and going away in TF 2.0.
@ -580,6 +582,7 @@ class SavedModelBuilder(_SavedModelBuilder):
# Validate the signature def map to ensure all included TensorInfos are # Validate the signature def map to ensure all included TensorInfos are
# properly populated. # properly populated.
signature_def_map = signature_def_map or {}
self._validate_signature_def_map(signature_def_map) self._validate_signature_def_map(signature_def_map)
# legacy_init_op is deprecated, and going away in TF 2.0. # legacy_init_op is deprecated, and going away in TF 2.0.
@ -774,3 +777,8 @@ def _add_asset_to_collection(asset_filename, asset_tensor):
asset_any_proto = Any() asset_any_proto = Any()
asset_any_proto.Pack(asset_proto) asset_any_proto.Pack(asset_proto)
ops.add_to_collection(constants.ASSETS_KEY, asset_any_proto) ops.add_to_collection(constants.ASSETS_KEY, asset_any_proto)
def _add_op_to_signature_def_map(signature_def_map, op, key):
if op is not None:
signature_def_map[key] = signature_def_utils.op_signature_def(op, key)

View File

@ -48,7 +48,6 @@ tf_export(
# CollectionDef key for the SavedModel main op. # CollectionDef key for the SavedModel main op.
MAIN_OP_KEY = "saved_model_main_op" MAIN_OP_KEY = "saved_model_main_op"
tf_export( tf_export(
"saved_model.MAIN_OP_KEY",
v1=["saved_model.MAIN_OP_KEY", v1=["saved_model.MAIN_OP_KEY",
"saved_model.constants.MAIN_OP_KEY"]).export_constant( "saved_model.constants.MAIN_OP_KEY"]).export_constant(
__name__, "MAIN_OP_KEY") __name__, "MAIN_OP_KEY")
@ -105,3 +104,8 @@ tf_export(
"saved_model.VARIABLES_FILENAME", "saved_model.VARIABLES_FILENAME",
"saved_model.constants.VARIABLES_FILENAME" "saved_model.constants.VARIABLES_FILENAME"
]).export_constant(__name__, "VARIABLES_FILENAME") ]).export_constant(__name__, "VARIABLES_FILENAME")
# The initialization and train ops for a MetaGraph are stored in the
# signature def map. The ops are added to the map with the following keys.
INIT_OP_SIGNATURE_KEY = "__saved_model_init_op"
TRAIN_OP_SIGNATURE_KEY = "__saved_model_train_op"

View File

@ -31,6 +31,7 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat from tensorflow.python.util import compat
@ -141,15 +142,46 @@ def _get_main_op_tensor(
RuntimeError: If the collection def corresponding to the main op key has RuntimeError: If the collection def corresponding to the main op key has
other than exactly one tensor. other than exactly one tensor.
""" """
# TODO(kathywu): Rename this method to _get_op_from_collection when
# dependency from SavedModelEstimator is removed.
collection_def = meta_graph_def_to_load.collection_def collection_def = meta_graph_def_to_load.collection_def
main_op_tensor = None init_op = None
if init_op_key in collection_def: if init_op_key in collection_def:
main_ops = collection_def[init_op_key].node_list.value init_op_list = collection_def[init_op_key].node_list.value
if len(main_ops) != 1: if len(init_op_list) != 1:
raise RuntimeError("Expected exactly one SavedModel main op. " raise RuntimeError("Expected exactly one SavedModel init op. "
"Found: {}".format(main_ops)) "Found: {}".format(init_op_list))
main_op_tensor = ops.get_collection(init_op_key)[0] init_op = ops.get_collection(init_op_key)[0]
return main_op_tensor return init_op
def _get_op_from_collection(meta_graph_def, op_key):
return _get_main_op_tensor(meta_graph_def, op_key)
def _get_op_from_signature_def(meta_graph_def, op_signature_key, import_scope):
"""Retrieve op stored in the imported meta graph's signature def."""
if op_signature_key in meta_graph_def.signature_def:
return signature_def_utils.load_op_from_signature_def(
meta_graph_def.signature_def[op_signature_key], op_signature_key,
import_scope)
else:
return None
def get_init_op(meta_graph_def, import_scope=None):
return (_get_op_from_signature_def(
meta_graph_def, constants.INIT_OP_SIGNATURE_KEY, import_scope) or
_get_op_from_collection(meta_graph_def, constants.MAIN_OP_KEY) or
_get_op_from_collection(meta_graph_def, constants.LEGACY_INIT_OP_KEY))
def get_train_op(meta_graph_def, import_scope=None):
train_op = _get_op_from_signature_def(
meta_graph_def, constants.TRAIN_OP_SIGNATURE_KEY, import_scope)
if train_op is None:
train_op = _get_op_from_collection(meta_graph_def, constants.TRAIN_OP_KEY)
return train_op
@tf_export(v1=[ @tf_export(v1=[
@ -359,11 +391,9 @@ class SavedModelLoader(object):
asset_tensors_dictionary = _get_asset_tensors( asset_tensors_dictionary = _get_asset_tensors(
self._export_dir, meta_graph_def, import_scope=import_scope) self._export_dir, meta_graph_def, import_scope=import_scope)
main_op_tensor = ( init_op = get_init_op(meta_graph_def, import_scope)
_get_main_op_tensor(meta_graph_def, constants.MAIN_OP_KEY) or if init_op is not None:
_get_main_op_tensor(meta_graph_def, constants.LEGACY_INIT_OP_KEY)) sess.run(fetches=[init_op], feed_dict=asset_tensors_dictionary)
if main_op_tensor is not None:
sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
def load(self, sess, tags, import_scope=None, **saver_kwargs): def load(self, sess, tags, import_scope=None, **saver_kwargs):
"""Load the MetaGraphDef graph and restore variable values into the session. """Load the MetaGraphDef graph and restore variable values into the session.

View File

@ -19,11 +19,13 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import shutil
from absl.testing import parameterized
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import state_ops from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
@ -42,55 +44,72 @@ SIMPLE_ADD_SAVED_MODEL = _get_export_dir("simple_add_saved_model")
SAVED_MODEL_WITH_MAIN_OP = _get_export_dir("saved_model_with_main_op") SAVED_MODEL_WITH_MAIN_OP = _get_export_dir("saved_model_with_main_op")
class SavedModelLoaderTest(test.TestCase): def build_graph_helper():
g = ops.Graph()
def setUp(self): with g.as_default():
"""Write test SavedModels to a temp directory."""
with session.Session(graph=ops.Graph()) as sess:
x = variables.VariableV1(5, name="x") x = variables.VariableV1(5, name="x")
y = variables.VariableV1(11, name="y") y = variables.VariableV1(11, name="y")
z = x + y z = x + y
foo_sig_def = signature_def_utils.build_signature_def({
"foo_input": utils.build_tensor_info(x)
}, {"foo_output": utils.build_tensor_info(z)})
bar_sig_def = signature_def_utils.build_signature_def({
"bar_x": utils.build_tensor_info(x),
"bar_y": utils.build_tensor_info(y)
}, {"bar_z": utils.build_tensor_info(z)})
return g, {"foo": foo_sig_def, "bar": bar_sig_def}, y
@parameterized.parameters((saved_model_builder.SavedModelBuilder,),
(saved_model_builder._SavedModelBuilder,))
class SavedModelLoaderTest(test.TestCase, parameterized.TestCase):
def export_simple_graph(self, builder_cls):
g, sig_def_map, _ = build_graph_helper()
with session.Session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer()) self.evaluate(variables.global_variables_initializer())
builder = builder_cls(SIMPLE_ADD_SAVED_MODEL)
foo_sig_def = signature_def_utils.build_signature_def( builder.add_meta_graph_and_variables(sess, ["foo_graph"], sig_def_map)
{"foo_input": utils.build_tensor_info(x)},
{"foo_output": utils.build_tensor_info(z)})
bar_sig_def = signature_def_utils.build_signature_def(
{"bar_x": utils.build_tensor_info(x),
"bar_y": utils.build_tensor_info(y)},
{"bar_z": utils.build_tensor_info(z)})
builder = saved_model_builder.SavedModelBuilder(SIMPLE_ADD_SAVED_MODEL)
builder.add_meta_graph_and_variables(
sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def})
builder.save() builder.save()
# Write SavedModel with a main_op def export_graph_with_main_op(self, builder_cls):
g, sig_def_map, y = build_graph_helper()
with session.Session(graph=g) as sess:
self.evaluate(variables.global_variables_initializer())
assign_op = control_flow_ops.group(state_ops.assign(y, 7)) assign_op = control_flow_ops.group(state_ops.assign(y, 7))
builder = saved_model_builder.SavedModelBuilder(SAVED_MODEL_WITH_MAIN_OP) builder = builder_cls(SAVED_MODEL_WITH_MAIN_OP)
if builder_cls == saved_model_builder._SavedModelBuilder:
builder.add_meta_graph_and_variables( builder.add_meta_graph_and_variables(
sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def}, sess, ["foo_graph"], sig_def_map, init_op=assign_op)
main_op=assign_op) else:
builder.add_meta_graph_and_variables(
sess, ["foo_graph"], sig_def_map, main_op=assign_op)
builder.save() builder.save()
def tearDown(self): def tearDown(self):
file_io.delete_recursively(test.get_temp_dir()) super(SavedModelLoaderTest, self).tearDown()
shutil.rmtree(test.get_temp_dir(), ignore_errors=True)
def test_load_function(self): def test_load_function(self, builder_cls):
self.export_simple_graph(builder_cls)
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
with self.session(graph=ops.Graph()) as sess: with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo_graph"]) loader.load(sess, ["foo_graph"])
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
self.export_graph_with_main_op(builder_cls)
loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
with self.session(graph=ops.Graph()) as sess: with self.session(graph=ops.Graph()) as sess:
loader2.load(sess, ["foo_graph"]) loader2.load(sess, ["foo_graph"])
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval()) self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
def test_load_graph(self): def test_load_graph(self, builder_cls):
self.export_simple_graph(builder_cls)
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
graph = ops.Graph() graph = ops.Graph()
loader.load_graph(graph, ["foo_graph"]) loader.load_graph(graph, ["foo_graph"])
@ -101,14 +120,15 @@ class SavedModelLoaderTest(test.TestCase):
with self.assertRaises(KeyError): with self.assertRaises(KeyError):
graph.get_tensor_by_name("z:0") graph.get_tensor_by_name("z:0")
with self.session(graph=graph) as sess: with self.session(graph=graph):
# Check that x and y are not initialized # Check that x and y are not initialized
with self.assertRaises(errors.FailedPreconditionError): with self.assertRaises(errors.FailedPreconditionError):
self.evaluate(x) self.evaluate(x)
with self.assertRaises(errors.FailedPreconditionError): with self.assertRaises(errors.FailedPreconditionError):
self.evaluate(y) self.evaluate(y)
def test_load_with_import_scope(self): def test_load_with_import_scope(self, builder_cls):
self.export_graph_with_main_op(builder_cls)
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
with self.session(graph=ops.Graph()) as sess: with self.session(graph=ops.Graph()) as sess:
saver, _ = loader.load_graph( saver, _ = loader.load_graph(
@ -119,6 +139,12 @@ class SavedModelLoaderTest(test.TestCase):
loader.restore_variables(sess, tf_saver.Saver()) loader.restore_variables(sess, tf_saver.Saver())
loader.restore_variables(sess, saver) loader.restore_variables(sess, saver)
if builder_cls == saved_model_builder._SavedModelBuilder:
with self.assertRaises(errors.NotFoundError):
loader.run_init_ops(sess, ["foo_graph"])
loader.run_init_ops(sess, ["foo_graph"], import_scope="baz")
else:
loader.run_init_ops(sess, ["foo_graph"]) loader.run_init_ops(sess, ["foo_graph"])
self.assertEqual(5, sess.graph.get_tensor_by_name("baz/x:0").eval()) self.assertEqual(5, sess.graph.get_tensor_by_name("baz/x:0").eval())
@ -131,7 +157,8 @@ class SavedModelLoaderTest(test.TestCase):
self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval()) self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval())
self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval()) self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval())
def test_restore_variables(self): def test_restore_variables(self, builder_cls):
self.export_graph_with_main_op(builder_cls)
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
with self.session(graph=ops.Graph()) as sess: with self.session(graph=ops.Graph()) as sess:
x = variables.VariableV1(0, name="x") x = variables.VariableV1(0, name="x")
@ -147,7 +174,8 @@ class SavedModelLoaderTest(test.TestCase):
loader.restore_variables(sess, tf_saver.Saver()) loader.restore_variables(sess, tf_saver.Saver())
self.assertEqual(55, self.evaluate(z)) self.assertEqual(55, self.evaluate(z))
def test_run_init_op(self): def test_run_init_op(self, builder_cls):
self.export_graph_with_main_op(builder_cls)
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP) loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
graph = ops.Graph() graph = ops.Graph()
saver, _ = loader.load_graph(graph, ["foo_graph"]) saver, _ = loader.load_graph(graph, ["foo_graph"])
@ -160,14 +188,16 @@ class SavedModelLoaderTest(test.TestCase):
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval()) self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
def test_parse_saved_model(self): def test_parse_saved_model(self, builder_cls):
self.export_simple_graph(builder_cls)
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
meta_graph = loader.get_meta_graph_def_from_tags(["foo_graph"]) meta_graph = loader.get_meta_graph_def_from_tags(["foo_graph"])
self.assertIsNotNone(meta_graph) self.assertIsNotNone(meta_graph)
self.assertIn("foo", meta_graph.signature_def) self.assertIn("foo", meta_graph.signature_def)
self.assertIn("bar", meta_graph.signature_def) self.assertIn("bar", meta_graph.signature_def)
def test_load_invalid_meta_graph(self): def test_load_invalid_meta_graph(self, builder_cls):
self.export_simple_graph(builder_cls)
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
loader.get_meta_graph_def_from_tags([]) loader.get_meta_graph_def_from_tags([])
@ -176,13 +206,16 @@ class SavedModelLoaderTest(test.TestCase):
with self.assertRaises(RuntimeError): with self.assertRaises(RuntimeError):
loader.get_meta_graph_def_from_tags(["not_a_graph"]) loader.get_meta_graph_def_from_tags(["not_a_graph"])
def test_load_saved_model_with_no_variables(self): def test_load_saved_model_with_no_variables(self, builder_cls):
"""Test that SavedModel runs saver when there appear to be no variables. """Test that SavedModel runs saver when there appear to be no variables.
When no variables are detected, this may mean that the variables were saved When no variables are detected, this may mean that the variables were saved
to different collections, or the collections weren't saved to the to different collections, or the collections weren't saved to the
SavedModel. If the SavedModel MetaGraphDef contains a saver, it should still SavedModel. If the SavedModel MetaGraphDef contains a saver, it should still
run in either of these cases. run in either of these cases.
Args:
builder_cls: SavedModelBuilder or _SavedModelBuilder class
""" """
path = _get_export_dir("no_variable_saved_model") path = _get_export_dir("no_variable_saved_model")
with session.Session(graph=ops.Graph()) as sess: with session.Session(graph=ops.Graph()) as sess:
@ -192,7 +225,7 @@ class SavedModelLoaderTest(test.TestCase):
11, name="y", collections=["not_global_variable"]) 11, name="y", collections=["not_global_variable"])
self.assertFalse(variables._all_saveable_objects()) self.assertFalse(variables._all_saveable_objects())
z = x + y z = x + y
sess.run(variables.variables_initializer([x, y])) self.evaluate(variables.variables_initializer([x, y]))
foo_sig_def = signature_def_utils.build_signature_def( foo_sig_def = signature_def_utils.build_signature_def(
{"foo_input": utils.build_tensor_info(x)}, {"foo_input": utils.build_tensor_info(x)},
@ -215,8 +248,9 @@ class SavedModelLoaderTest(test.TestCase):
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval()) self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval()) self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
def test_load_saved_model_graph_with_return_elements(self): def test_load_saved_model_graph_with_return_elements(self, builder_cls):
"""Ensure that the correct elements are returned.""" """Ensure that the correct elements are returned."""
self.export_simple_graph(builder_cls)
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL) loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
graph = ops.Graph() graph = ops.Graph()
_, ret = loader.load_graph(graph, ["foo_graph"], _, ret = loader.load_graph(graph, ["foo_graph"],
@ -228,5 +262,6 @@ class SavedModelLoaderTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "not found in graph"): with self.assertRaisesRegexp(ValueError, "not found in graph"):
loader.load_graph(graph, ["foo_graph"], return_elements=["z:0"]) loader.load_graph(graph, ["foo_graph"], return_elements=["z:0"])
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -146,6 +146,18 @@ class SavedModelTest(SavedModelTestBase):
sess, ["foo"], sess, ["foo"],
signature_def_map={"foo_key": foo_signature}) signature_def_map={"foo_key": foo_signature})
def _validate_sig_def_keys(self, builder, valid_tensor_info, invalid_key):
with self.session(graph=ops.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
foo_signature = signature_def_utils.build_signature_def(
dict(), {"foo_key": valid_tensor_info}, "foo")
self.assertRaises(
KeyError,
builder.add_meta_graph_and_variables,
sess, ["foo"],
signature_def_map={invalid_key: foo_signature})
def testMaybeSavedModelDir(self): def testMaybeSavedModelDir(self):
base_path = test.test_src_dir_path("/python/saved_model") base_path = test.test_src_dir_path("/python/saved_model")
self.assertFalse(loader.maybe_saved_model_directory(base_path)) self.assertFalse(loader.maybe_saved_model_directory(base_path))
@ -583,6 +595,15 @@ class SavedModelTest(SavedModelTestBase):
self._validate_inputs_tensor_info_fail(builder, tensor_empty) self._validate_inputs_tensor_info_fail(builder, tensor_empty)
self._validate_outputs_tensor_info_fail(builder, tensor_empty) self._validate_outputs_tensor_info_fail(builder, tensor_empty)
valid_tensor_info = meta_graph_pb2.TensorInfo()
valid_tensor_info.name = "foo"
valid_tensor_info.dtype = types_pb2.DT_FLOAT
self._validate_sig_def_keys(builder, valid_tensor_info,
constants.INIT_OP_SIGNATURE_KEY)
self._validate_sig_def_keys(builder, valid_tensor_info,
constants.TRAIN_OP_SIGNATURE_KEY)
def testSignatureDefValidationSucceedsWithName(self): def testSignatureDefValidationSucceedsWithName(self):
tensor_with_name = meta_graph_pb2.TensorInfo() tensor_with_name = meta_graph_pb2.TensorInfo()
tensor_with_name.name = "foo" tensor_with_name.name = "foo"
@ -782,7 +803,7 @@ class SavedModelTest(SavedModelTestBase):
self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt", self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt",
"foo bar baz 0", "asset_file_tensor_0:0") "foo bar baz 0", "asset_file_tensor_0:0")
def testCustomMainOp(self): def testCustomInitOp(self):
export_dir = self._get_export_dir("test_main_op") export_dir = self._get_export_dir("test_main_op")
builder = saved_model_builder._SavedModelBuilder(export_dir) builder = saved_model_builder._SavedModelBuilder(export_dir)
@ -800,11 +821,11 @@ class SavedModelTest(SavedModelTestBase):
# Set up an assignment op to be run as part of the main_op. # Set up an assignment op to be run as part of the main_op.
with ops.control_dependencies([main_op.main_op()]): with ops.control_dependencies([main_op.main_op()]):
add_v1_v2 = math_ops.add(v1._ref(), v2._ref()) add_v1_v2 = math_ops.add(v1._ref(), v2._ref())
custom_main_op = control_flow_ops.group(state_ops.assign(v3, add_v1_v2)) custom_init_op = control_flow_ops.group(state_ops.assign(v3, add_v1_v2))
self.evaluate(custom_main_op) self.evaluate(custom_init_op)
builder.add_meta_graph_and_variables( builder.add_meta_graph_and_variables(
sess, ["foo"], main_op=custom_main_op) sess, ["foo"], init_op=custom_init_op)
# Save the SavedModel to disk. # Save the SavedModel to disk.
builder.save() builder.save()
@ -817,80 +838,6 @@ class SavedModelTest(SavedModelTestBase):
# the main_op, following a restore. # the main_op, following a restore.
self.assertEqual(3, ops.get_collection("v")[2].eval()) self.assertEqual(3, ops.get_collection("v")[2].eval())
def testLegacyInitOp(self):
export_dir = self._get_export_dir("test_legacy_init_op")
builder = saved_model_builder._SavedModelBuilder(export_dir)
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
# Initialize another variable `v3` to 42.
v3 = variables.VariableV1(42, name="v3", trainable=False, collections=[])
ops.add_to_collection("v", v3)
# Set up an assignment op to be run as part of the legacy_init_op.
assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2))
legacy_init_op = control_flow_ops.group(assign_v3, name="legacy_init_op")
self.evaluate(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(
sess, ["foo"], legacy_init_op=legacy_init_op)
# Save the SavedModel to disk.
builder.save()
with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(1, ops.get_collection("v")[0].eval())
self.assertEqual(2, ops.get_collection("v")[1].eval())
# Evaluates to the sum of the first two variables and assigned as part of
# the legacy_init_op, following a restore.
self.assertEqual(3, ops.get_collection("v")[2].eval())
def testLegacyInitOpWithNonEmptyCollection(self):
export_dir = self._get_export_dir(
"test_legacy_init_op_with_non_empty_collection")
self._testInitOpsWithNonEmptyCollection(
export_dir, constants.LEGACY_INIT_OP_KEY)
def testMainOpWithNonEmptyCollection(self):
export_dir = self._get_export_dir(
"test_main_op_with_non_empty_collection")
self._testInitOpsWithNonEmptyCollection(export_dir, constants.MAIN_OP_KEY)
def _testInitOpsWithNonEmptyCollection(self, export_dir, key):
builder = saved_model_builder._SavedModelBuilder(export_dir)
g = ops.Graph()
with self.session(graph=g) as sess:
# Initialize variable `v1` to 1.
v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
# Initialize another variable `v2` to 42.
v2 = variables.VariableV1(42, name="v2", trainable=False, collections=[])
ops.add_to_collection("v", v2)
# Set up an assignment op to be run as part of the init op.
assign_v2 = state_ops.assign(v2, v1)
init_op = control_flow_ops.group(assign_v2, name="init_op")
self.evaluate(variables.global_variables_initializer())
ops.add_to_collection(key, control_flow_ops.no_op())
# ValueError should be raised since the LEGACY_INIT_OP_KEY collection
# is not empty and we don't support multiple init ops.
with self.assertRaisesRegexp(ValueError, "Graph already contains"):
builder.add_meta_graph_and_variables(
sess, ["foo"], legacy_init_op=init_op)
# We shouldn't be able to add as MAIN_OP, either.
with self.assertRaisesRegexp(ValueError, "Graph already contains"):
builder.add_meta_graph_and_variables(sess, ["foo"], main_op=init_op)
def testTrainOp(self): def testTrainOp(self):
export_dir = self._get_export_dir("test_train_op") export_dir = self._get_export_dir("test_train_op")
builder = saved_model_builder._SavedModelBuilder(export_dir) builder = saved_model_builder._SavedModelBuilder(export_dir)
@ -906,19 +853,17 @@ class SavedModelTest(SavedModelTestBase):
train_op = state_ops.assign_add(v1, v2) train_op = state_ops.assign_add(v1, v2)
self.evaluate(train_op) self.evaluate(train_op)
# TODO(karmel): remove explicit call when in the public method. builder.add_meta_graph_and_variables(sess, ["foo"], train_op=train_op)
builder._add_train_op(train_op)
builder.add_meta_graph_and_variables(sess, ["foo"])
# Save the SavedModel to disk. # Save the SavedModel to disk.
builder.save() builder.save()
with self.session(graph=ops.Graph()) as sess: with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir) meta_graph_def = loader.load(sess, ["foo"], export_dir)
self.assertEqual(3, ops.get_collection("v")[0].eval()) self.assertEqual(3, ops.get_collection("v")[0].eval())
self.assertEqual(2, ops.get_collection("v")[1].eval()) self.assertEqual(2, ops.get_collection("v")[1].eval())
self.assertIsInstance( self.assertIsInstance(
ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor) loader_impl.get_train_op(meta_graph_def), ops.Tensor)
def testTrainOpGroup(self): def testTrainOpGroup(self):
export_dir = self._get_export_dir("test_train_op_group") export_dir = self._get_export_dir("test_train_op_group")
@ -935,19 +880,17 @@ class SavedModelTest(SavedModelTestBase):
train_op = control_flow_ops.group() train_op = control_flow_ops.group()
self.evaluate(train_op) self.evaluate(train_op)
# TODO(karmel): remove explicit call when in the public method. builder.add_meta_graph_and_variables(sess, ["foo"], train_op=train_op)
builder._add_train_op(train_op)
builder.add_meta_graph_and_variables(sess, ["foo"])
# Save the SavedModel to disk. # Save the SavedModel to disk.
builder.save() builder.save()
with self.session(graph=ops.Graph()) as sess: with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir) meta_graph_def = loader.load(sess, ["foo"], export_dir)
self.assertEqual(1, ops.get_collection("v")[0].eval()) self.assertEqual(1, ops.get_collection("v")[0].eval())
self.assertEqual(2, ops.get_collection("v")[1].eval()) self.assertEqual(2, ops.get_collection("v")[1].eval())
self.assertIsInstance( self.assertIsInstance(
ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Operation) loader_impl.get_train_op(meta_graph_def), ops.Operation)
def testTrainOpAfterVariables(self): def testTrainOpAfterVariables(self):
export_dir = self._get_export_dir("test_train_op_after_variables") export_dir = self._get_export_dir("test_train_op_after_variables")
@ -965,17 +908,15 @@ class SavedModelTest(SavedModelTestBase):
train_op = state_ops.assign_add(v1, v2) train_op = state_ops.assign_add(v1, v2)
self.evaluate(train_op) self.evaluate(train_op)
# TODO(karmel): remove explicit call when in the public method. builder.add_meta_graph(["foo"], train_op=train_op)
builder._add_train_op(train_op)
builder.add_meta_graph(["foo"])
# Save the SavedModel to disk. # Save the SavedModel to disk.
builder.save() builder.save()
with self.session(graph=ops.Graph()) as sess: with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir) meta_graph_def = loader.load(sess, ["foo"], export_dir)
self.assertIsInstance( self.assertIsInstance(
ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor) loader_impl.get_train_op(meta_graph_def), ops.Tensor)
with self.session(graph=ops.Graph()) as sess: with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["pre_foo"], export_dir) loader.load(sess, ["pre_foo"], export_dir)
@ -1288,76 +1229,6 @@ class SavedModelTest(SavedModelTestBase):
self.assertEqual( self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
def testStripDefaultAttrs(self):
export_dir = self._get_export_dir("test_strip_default_attrs")
builder = saved_model_builder._SavedModelBuilder(export_dir)
# Add a graph with two float32 variables and a Complex Op composing them
# with strip_default_attrs enabled.
with session.Session(graph=ops.Graph()) as sess:
real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
self.evaluate(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(
sess, ["foo"], strip_default_attrs=True)
# Add a graph with the same float32 variables and a Complex Op composing
# them with strip_default_attrs disabled.
with session.Session(graph=ops.Graph()) as sess:
real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
self.evaluate(variables.global_variables_initializer())
builder.add_meta_graph(["bar"], strip_default_attrs=False)
# Save the SavedModel to disk in text format.
builder.save(as_text=True)
# Loading graph "foo" via the loader must restore the defaults for the
# "Complex" node based on the "Complex" OpDef in the Op registry.
sess = session.Session(graph=ops.Graph())
meta_graph_def = loader.load(sess, ["foo"], export_dir)
complex_node = test_util.get_node_def_from_graph("complex",
meta_graph_def.graph_def)
self.assertIn("T", complex_node.attr)
self.assertIn("Tout", complex_node.attr)
# Load graph "foo" from disk as-is to verify default attrs are stripped.
# pylint: disable=protected-access
saved_model_pb = loader_impl._parse_saved_model(export_dir)
self.assertIsNotNone(saved_model_pb)
# pylint: enable=protected-access
meta_graph_foo_def = None
meta_graph_bar_def = None
for meta_graph_def in saved_model_pb.meta_graphs:
if set(meta_graph_def.meta_info_def.tags) == set(["foo"]):
meta_graph_foo_def = meta_graph_def
elif set(meta_graph_def.meta_info_def.tags) == set(["bar"]):
meta_graph_bar_def = meta_graph_def
self.assertIsNotNone(meta_graph_foo_def)
self.assertIsNotNone(meta_graph_bar_def)
# "Complex" Op has 2 attributes with defaults:
# o "T" : float32. (input type)
# o "Tout" : complex64. (output type)
# "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
# Graph "foo" was saved with strip_default_attrs set to True.
node_def = test_util.get_node_def_from_graph("complex",
meta_graph_foo_def.graph_def)
self.assertNotIn("T", node_def.attr)
self.assertNotIn("Tout", node_def.attr)
# "Complex" Op in graph "bar" must have attributes "T" and "Tout".
# Graph "bar" was saved with strip_default_attrs set to False.
node_def = test_util.get_node_def_from_graph("complex",
meta_graph_bar_def.graph_def)
self.assertIn("T", node_def.attr)
self.assertIn("Tout", node_def.attr)
# Tests the behavior of loading SavedModels that having missing attrs or attrs # Tests the behavior of loading SavedModels that having missing attrs or attrs
# with incorrect types. # with incorrect types.
def testInconsistentConsumerDefaultAttrs(self): def testInconsistentConsumerDefaultAttrs(self):
@ -1484,6 +1355,149 @@ class SavedModelV1Test(SavedModelTestBase):
compat.as_bytes("ignored.txt")) compat.as_bytes("ignored.txt"))
self.assertFalse(file_io.file_exists(ignored_asset_path)) self.assertFalse(file_io.file_exists(ignored_asset_path))
def testLegacyInitOpWithNonEmptyCollection(self):
export_dir = self._get_export_dir(
"test_legacy_init_op_with_non_empty_collection")
self._testInitOpsWithNonEmptyCollection(export_dir,
constants.LEGACY_INIT_OP_KEY)
def testMainOpWithNonEmptyCollection(self):
export_dir = self._get_export_dir("test_main_op_with_non_empty_collection")
self._testInitOpsWithNonEmptyCollection(export_dir, constants.MAIN_OP_KEY)
def _testInitOpsWithNonEmptyCollection(self, export_dir, key):
builder = saved_model_builder.SavedModelBuilder(export_dir)
g = ops.Graph()
with self.session(graph=g) as sess:
# Initialize variable `v1` to 1.
v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
# Initialize another variable `v2` to 42.
v2 = variables.VariableV1(42, name="v2", trainable=False, collections=[])
ops.add_to_collection("v", v2)
# Set up an assignment op to be run as part of the init op.
assign_v2 = state_ops.assign(v2, v1)
init_op = control_flow_ops.group(assign_v2, name="init_op")
self.evaluate(variables.global_variables_initializer())
ops.add_to_collection(key, control_flow_ops.no_op())
# ValueError should be raised since the LEGACY_INIT_OP_KEY collection
# is not empty and we don't support multiple init ops.
with self.assertRaisesRegexp(ValueError, "Graph already contains"):
builder.add_meta_graph_and_variables(
sess, ["foo"], legacy_init_op=init_op)
# We shouldn't be able to add as MAIN_OP, either.
with self.assertRaisesRegexp(ValueError, "Graph already contains"):
builder.add_meta_graph_and_variables(sess, ["foo"], main_op=init_op)
def testStripDefaultAttrs(self):
export_dir = self._get_export_dir("test_strip_default_attrs")
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Add a graph with two float32 variables and a Complex Op composing them
# with strip_default_attrs enabled.
with session.Session(graph=ops.Graph()) as sess:
real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
self.evaluate(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(
sess, ["foo"], strip_default_attrs=True)
# Add a graph with the same float32 variables and a Complex Op composing
# them with strip_default_attrs disabled.
with session.Session(graph=ops.Graph()) as sess:
real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
self.evaluate(variables.global_variables_initializer())
builder.add_meta_graph(["bar"], strip_default_attrs=False)
# Save the SavedModel to disk in text format.
builder.save(as_text=True)
# Loading graph "foo" via the loader must restore the defaults for the
# "Complex" node based on the "Complex" OpDef in the Op registry.
sess = session.Session(graph=ops.Graph())
meta_graph_def = loader.load(sess, ["foo"], export_dir)
complex_node = test_util.get_node_def_from_graph("complex",
meta_graph_def.graph_def)
self.assertIn("T", complex_node.attr)
self.assertIn("Tout", complex_node.attr)
# Load graph "foo" from disk as-is to verify default attrs are stripped.
# pylint: disable=protected-access
saved_model_pb = loader_impl._parse_saved_model(export_dir)
self.assertIsNotNone(saved_model_pb)
# pylint: enable=protected-access
meta_graph_foo_def = None
meta_graph_bar_def = None
for meta_graph_def in saved_model_pb.meta_graphs:
if set(meta_graph_def.meta_info_def.tags) == set(["foo"]):
meta_graph_foo_def = meta_graph_def
elif set(meta_graph_def.meta_info_def.tags) == set(["bar"]):
meta_graph_bar_def = meta_graph_def
self.assertIsNotNone(meta_graph_foo_def)
self.assertIsNotNone(meta_graph_bar_def)
# "Complex" Op has 2 attributes with defaults:
# o "T" : float32. (input type)
# o "Tout" : complex64. (output type)
# "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
# Graph "foo" was saved with strip_default_attrs set to True.
node_def = test_util.get_node_def_from_graph("complex",
meta_graph_foo_def.graph_def)
self.assertNotIn("T", node_def.attr)
self.assertNotIn("Tout", node_def.attr)
# "Complex" Op in graph "bar" must have attributes "T" and "Tout".
# Graph "bar" was saved with strip_default_attrs set to False.
node_def = test_util.get_node_def_from_graph("complex",
meta_graph_bar_def.graph_def)
self.assertIn("T", node_def.attr)
self.assertIn("Tout", node_def.attr)
def testLegacyInitOp(self):
export_dir = self._get_export_dir("test_legacy_init_op")
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
# Initialize another variable `v3` to 42.
v3 = variables.VariableV1(42, name="v3", trainable=False, collections=[])
ops.add_to_collection("v", v3)
# Set up an assignment op to be run as part of the init_op.
assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2))
legacy_init_op = control_flow_ops.group(assign_v3, name="legacy_init_op")
self.evaluate(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(
sess, ["foo"], legacy_init_op=legacy_init_op)
# Save the SavedModel to disk.
builder.save()
with self.session(graph=ops.Graph()) as sess:
loader.load(sess, ["foo"], export_dir)
self.assertEqual(1, ops.get_collection("v")[0].eval())
self.assertEqual(2, ops.get_collection("v")[1].eval())
# Evaluates to the sum of the first two variables and assigned as part of
# the legacy_init_op, following a restore.
self.assertEqual(3, ops.get_collection("v")[2].eval())
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -24,6 +24,8 @@ from __future__ import print_function
from tensorflow.python.saved_model.signature_def_utils_impl import build_signature_def from tensorflow.python.saved_model.signature_def_utils_impl import build_signature_def
from tensorflow.python.saved_model.signature_def_utils_impl import classification_signature_def from tensorflow.python.saved_model.signature_def_utils_impl import classification_signature_def
from tensorflow.python.saved_model.signature_def_utils_impl import is_valid_signature from tensorflow.python.saved_model.signature_def_utils_impl import is_valid_signature
from tensorflow.python.saved_model.signature_def_utils_impl import load_op_from_signature_def
from tensorflow.python.saved_model.signature_def_utils_impl import op_signature_def
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def
from tensorflow.python.saved_model.signature_def_utils_impl import supervised_eval_signature_def from tensorflow.python.saved_model.signature_def_utils_impl import supervised_eval_signature_def

View File

@ -21,9 +21,10 @@ from __future__ import print_function
from tensorflow.core.framework import types_pb2 from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import utils from tensorflow.python.saved_model import utils_impl as utils
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
@ -349,3 +350,51 @@ def _is_valid_classification_signature(signature_def):
return False return False
return True return True
def op_signature_def(op, key):
"""Creates a signature def with the output pointing to an op.
Note that op isn't strictly enforced to be an Op object, and may be a Tensor.
It is recommended to use the build_signature_def() function for Tensors.
Args:
op: An Op (or possibly Tensor).
key: Key to graph element in the SignatureDef outputs.
Returns:
A SignatureDef with a single output pointing to the op.
"""
# Use build_tensor_info_from_op, which creates a TensorInfo from the element's
# name.
return build_signature_def(outputs={key: utils.build_tensor_info_from_op(op)})
def load_op_from_signature_def(signature_def, key, import_scope=None):
"""Load an Op from a SignatureDef created by op_signature_def().
Args:
signature_def: a SignatureDef proto
key: string key to op in the SignatureDef outputs.
import_scope: Scope used to import the op
Returns:
Op (or possibly Tensor) in the graph with the same name as saved in the
SignatureDef.
Raises:
NotFoundError: If the op could not be found in the graph.
"""
tensor_info = signature_def.outputs[key]
try:
# The init and train ops are not strictly enforced to be operations, so
# retrieve any graph element (can be either op or tensor).
return utils.get_element_from_tensor_info(
tensor_info, import_scope=import_scope)
except KeyError:
raise errors.NotFoundError(
None, None,
'The {0} could not be found in the graph. Please make sure the '
'SavedModel was created by the internal _SavedModelBuilder. If you '
'are using the public API, please make sure the SignatureDef in the '
'SavedModel does not contain the key "{0}".'.format(key))

View File

@ -23,6 +23,7 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils_impl from tensorflow.python.saved_model import signature_def_utils_impl
@ -413,5 +414,22 @@ class SignatureDefUtilsTest(test.TestCase):
{}, {},
signature_constants.PREDICT_METHOD_NAME) signature_constants.PREDICT_METHOD_NAME)
def testOpSignatureDef(self):
key = "adding_1_and_2_key"
add_op = math_ops.add(1, 2, name="adding_1_and_2")
signature_def = signature_def_utils_impl.op_signature_def(add_op, key)
self.assertIn(key, signature_def.outputs)
self.assertEqual(add_op.name, signature_def.outputs[key].name)
def testLoadOpFromSignatureDef(self):
key = "adding_1_and_2_key"
add_op = math_ops.add(1, 2, name="adding_1_and_2")
signature_def = signature_def_utils_impl.op_signature_def(add_op, key)
self.assertEqual(
add_op,
signature_def_utils_impl.load_op_from_signature_def(signature_def, key))
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -141,6 +141,27 @@ def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
raise ValueError("Invalid TensorInfo.encoding: %s" % encoding) raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
"""Returns the element in the graph described by a TensorInfo proto.
Args:
tensor_info: A TensorInfo proto describing an Op or Tensor by name.
graph: The tf.Graph in which tensors are looked up. If None, the current
default graph is used.
import_scope: If not None, names in `tensor_info` are prefixed with this
string before lookup.
Returns:
Op or tensor in `graph` described by `tensor_info`.
Raises:
KeyError: If `tensor_info` does not correspond to an op or tensor in `graph`
"""
graph = graph or ops.get_default_graph()
return graph.as_graph_element(
ops.prepend_name_scope(tensor_info.name, import_scope=import_scope))
# Path helpers. # Path helpers.

View File

@ -32,10 +32,6 @@ tf_module {
name: "GPU" name: "GPU"
mtype: "<type \'str\'>" mtype: "<type \'str\'>"
} }
member {
name: "MAIN_OP_KEY"
mtype: "<type \'str\'>"
}
member { member {
name: "PREDICT_INPUTS" name: "PREDICT_INPUTS"
mtype: "<type \'str\'>" mtype: "<type \'str\'>"