Save main_op and train_op in SignatureDefs instead of collections.
PiperOrigin-RevId: 223268983
This commit is contained in:
parent
99c20bf32e
commit
693c1ad4a9
@ -133,5 +133,6 @@ filegroup(
|
||||
"testdata/half_plus_two_pbtxt/**",
|
||||
"testdata/half_plus_two_main_op/**",
|
||||
"testdata/half_plus_two/**",
|
||||
"testdata/half_plus_two_v2/**",
|
||||
]),
|
||||
)
|
||||
|
@ -33,10 +33,10 @@ constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
|
||||
/// SavedModel text format proto filename.
|
||||
constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
|
||||
|
||||
/// SavedModel legacy init op key.
|
||||
/// SavedModel legacy init op collection key. Used in v1 SavedModels.
|
||||
constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
|
||||
|
||||
/// SavedModel main op key.
|
||||
/// SavedModel main op collection key. Used in v1 SavedModels.
|
||||
constexpr char kSavedModelMainOpKey[] = "saved_model_main_op";
|
||||
|
||||
/// Directory in which to save the SavedModel variables.
|
||||
@ -45,6 +45,11 @@ constexpr char kSavedModelVariablesDirectory[] = "variables";
|
||||
/// SavedModel variables filename.
|
||||
constexpr char kSavedModelVariablesFilename[] = "variables";
|
||||
|
||||
/// SavedModel SignatureDef keys for the initialization and train ops. Used in
|
||||
/// V2 SavedModels.
|
||||
constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op";
|
||||
constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op";
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
|
||||
|
@ -122,38 +122,58 @@ Status RunOnce(const RunOptions& run_options,
|
||||
return run_status;
|
||||
}
|
||||
|
||||
bool HasMainOp(const MetaGraphDef& meta_graph_def) {
|
||||
const auto& collection_def_map = meta_graph_def.collection_def();
|
||||
if (collection_def_map.find(kSavedModelMainOpKey) !=
|
||||
collection_def_map.end()) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Status RunMainOp(const RunOptions& run_options, const string& export_dir,
|
||||
// RunInitOp will return OK if the initialization op was run successfully.
|
||||
// An empty init_op_name indicates that there are no init ops to run.
|
||||
Status RunInitOp(const RunOptions& run_options, const string& export_dir,
|
||||
const MetaGraphDef& meta_graph_def,
|
||||
const std::vector<AssetFileDef>& asset_file_defs,
|
||||
Session* session, const string& main_op_key) {
|
||||
LOG(INFO) << "Running MainOp with key " << main_op_key
|
||||
<< " on SavedModel bundle.";
|
||||
const auto& collection_def_map = meta_graph_def.collection_def();
|
||||
const auto main_op_it = collection_def_map.find(main_op_key);
|
||||
if (main_op_it != collection_def_map.end()) {
|
||||
if (main_op_it->second.node_list().value_size() != 1) {
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("Expected exactly one main op in : ", export_dir));
|
||||
}
|
||||
Session* session, const string& init_op_name) {
|
||||
if (!init_op_name.empty()) {
|
||||
LOG(INFO) << "Running initialization op on SavedModel bundle.";
|
||||
std::vector<std::pair<string, Tensor>> inputs;
|
||||
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
|
||||
RunMetadata run_metadata;
|
||||
const StringPiece main_op_name = main_op_it->second.node_list().value(0);
|
||||
return RunOnce(run_options, inputs, {}, {string(main_op_name)},
|
||||
return RunOnce(run_options, inputs, {}, {init_op_name},
|
||||
nullptr /* outputs */, &run_metadata, session);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// A SavedModel may store the name of the initialization op to run in the
|
||||
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
|
||||
// exists, then the collection must contain exactly one op.
|
||||
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
|
||||
string* init_op_name) {
|
||||
const auto& sig_def_map = meta_graph_def.signature_def();
|
||||
const auto& init_op_sig_it =
|
||||
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
|
||||
if (init_op_sig_it != sig_def_map.end()) {
|
||||
*init_op_name = init_op_sig_it->second.outputs()
|
||||
.find(kSavedModelInitOpSignatureKey)
|
||||
->second.name();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto& collection_def_map = meta_graph_def.collection_def();
|
||||
string init_op_collection_key;
|
||||
if (collection_def_map.find(kSavedModelMainOpKey) !=
|
||||
collection_def_map.end()) {
|
||||
init_op_collection_key = kSavedModelMainOpKey;
|
||||
} else {
|
||||
init_op_collection_key = kSavedModelLegacyInitOpKey;
|
||||
}
|
||||
|
||||
const auto init_op_it = collection_def_map.find(init_op_collection_key);
|
||||
if (init_op_it != collection_def_map.end()) {
|
||||
if (init_op_it->second.node_list().value_size() != 1) {
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("Expected exactly one main op in : ", export_dir));
|
||||
}
|
||||
*init_op_name = init_op_it->second.node_list().value(0);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RunRestore(const RunOptions& run_options, const string& export_dir,
|
||||
const StringPiece restore_op_name,
|
||||
const StringPiece variable_filename_const_op_name,
|
||||
@ -236,15 +256,12 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
|
||||
bundle->meta_graph_def.saver_def().restore_op_name(),
|
||||
bundle->meta_graph_def.saver_def().filename_tensor_name(),
|
||||
asset_file_defs, bundle->session.get()));
|
||||
if (HasMainOp(bundle->meta_graph_def)) {
|
||||
TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir,
|
||||
bundle->meta_graph_def, asset_file_defs,
|
||||
bundle->session.get(), kSavedModelMainOpKey));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(RunMainOp(
|
||||
run_options, export_dir, bundle->meta_graph_def, asset_file_defs,
|
||||
bundle->session.get(), kSavedModelLegacyInitOpKey));
|
||||
}
|
||||
string init_op_name;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
|
||||
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
|
||||
asset_file_defs, bundle->session.get(),
|
||||
init_op_name));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -36,6 +36,8 @@ constexpr char kTestDataMainOp[] =
|
||||
"cc/saved_model/testdata/half_plus_two_main_op/00000123";
|
||||
constexpr char kTestDataSharded[] =
|
||||
"cc/saved_model/testdata/half_plus_two/00000123";
|
||||
constexpr char kTestDataInitOpV2[] =
|
||||
"cc/saved_model/testdata/half_plus_two_v2/00000123";
|
||||
|
||||
class LoaderTest : public ::testing::Test {
|
||||
protected:
|
||||
@ -227,5 +229,17 @@ TEST_F(LoaderTest, MaybeSavedModelDirectory) {
|
||||
EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir));
|
||||
}
|
||||
|
||||
TEST_F(LoaderTest, SavedModelInitOpV2Format) {
|
||||
SavedModelBundle bundle;
|
||||
SessionOptions session_options;
|
||||
RunOptions run_options;
|
||||
|
||||
const string export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataInitOpV2);
|
||||
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
|
||||
{kSavedModelTagServe}, &bundle));
|
||||
CheckSavedModelBundle(export_dir, bundle);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
1
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt
vendored
Normal file
1
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt
vendored
Normal file
@ -0,0 +1 @@
|
||||
asset-file-contents
|
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index
vendored
Normal file
Binary file not shown.
@ -125,7 +125,7 @@ def save_keras_model(
|
||||
export_dir = export_helpers.get_timestamped_export_dir(saved_model_path)
|
||||
temp_export_dir = export_helpers.get_temp_export_dir(export_dir)
|
||||
|
||||
builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
|
||||
builder = saved_model_builder._SavedModelBuilder(temp_export_dir)
|
||||
|
||||
# Manually save variables to export them in an object-based checkpoint. This
|
||||
# skips the `builder.add_meta_graph_and_variables()` step, which saves a
|
||||
@ -227,9 +227,10 @@ def _export_mode(
|
||||
g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)
|
||||
|
||||
# Extract update and train ops from train/test/predict functions.
|
||||
train_op = None
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
clone._make_train_function()
|
||||
builder._add_train_op(clone.train_function.updates_op)
|
||||
train_op = clone.train_function.updates_op
|
||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||
clone._make_test_function()
|
||||
else:
|
||||
@ -264,7 +265,8 @@ def _export_mode(
|
||||
model_fn_lib.EXPORT_TAG_MAP[mode],
|
||||
signature_def_map=_create_signature_def_map(clone, mode),
|
||||
saver=saver_lib.Saver(clone_var_list),
|
||||
main_op=variables.local_variables_initializer())
|
||||
init_op=variables.local_variables_initializer(),
|
||||
train_op=train_op)
|
||||
return None
|
||||
|
||||
|
||||
|
@ -35,7 +35,6 @@ from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.keras.utils import tf_utils
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import loader_impl
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.training import training as training_module
|
||||
@ -254,7 +253,7 @@ def load_model(sess, path, mode):
|
||||
outputs = {
|
||||
k: sess.graph.get_tensor_by_name(v.name)
|
||||
for k, v in meta_graph_def.signature_def[sig_def_key].outputs.items()}
|
||||
return inputs, outputs
|
||||
return inputs, outputs, meta_graph_def
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@ -331,7 +330,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
||||
|
||||
# Load predict graph, and test predictions
|
||||
with session.Session(graph=ops.Graph()) as sess:
|
||||
inputs, outputs = load_model(sess, output_path,
|
||||
inputs, outputs, _ = load_model(sess, output_path,
|
||||
model_fn_lib.ModeKeys.PREDICT)
|
||||
|
||||
predictions = sess.run(outputs[output_name],
|
||||
@ -341,7 +340,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
||||
if optimizer:
|
||||
# Load eval graph, and test predictions, loss and metric values
|
||||
with session.Session(graph=ops.Graph()) as sess:
|
||||
inputs, outputs = load_model(sess, output_path,
|
||||
inputs, outputs, _ = load_model(sess, output_path,
|
||||
model_fn_lib.ModeKeys.EVAL)
|
||||
|
||||
# First obtain the loss and predictions, and run the metric update op by
|
||||
@ -365,8 +364,8 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
||||
|
||||
# Load train graph, and check for the train op, and prediction values
|
||||
with session.Session(graph=ops.Graph()) as sess:
|
||||
inputs, outputs = load_model(sess, output_path,
|
||||
model_fn_lib.ModeKeys.TRAIN)
|
||||
inputs, outputs, meta_graph_def = load_model(
|
||||
sess, output_path, model_fn_lib.ModeKeys.TRAIN)
|
||||
self.assertEqual(int(train_before_export),
|
||||
sess.run(training_module.get_global_step()))
|
||||
self.assertIn('loss', outputs)
|
||||
@ -375,7 +374,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
||||
self.assertIn('predictions/' + output_name, outputs)
|
||||
|
||||
# Train for a step
|
||||
train_op = ops.get_collection(constants.TRAIN_OP_KEY)
|
||||
train_op = loader_impl.get_train_op(meta_graph_def)
|
||||
train_outputs, _ = sess.run(
|
||||
[outputs, train_op], {inputs[input_name]: input_arr,
|
||||
inputs[target_name]: target_arr})
|
||||
@ -402,7 +401,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
||||
output_path = keras_saved_model.save_keras_model(
|
||||
model, saved_model_path, custom_objects={'relu6': relu6})
|
||||
with session.Session(graph=ops.Graph()) as sess:
|
||||
inputs, outputs = load_model(sess, output_path,
|
||||
inputs, outputs, _ = load_model(sess, output_path,
|
||||
model_fn_lib.ModeKeys.PREDICT)
|
||||
input_name = model.input_names[0]
|
||||
output_name = model.output_names[0]
|
||||
|
@ -83,6 +83,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":constants",
|
||||
":signature_def_utils",
|
||||
":utils",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
@ -114,6 +115,7 @@ py_test(
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variables",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -33,6 +33,7 @@ from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import signature_def_utils
|
||||
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||
from tensorflow.python.training import saver as tf_saver
|
||||
from tensorflow.python.util import compat
|
||||
@ -40,6 +41,8 @@ from tensorflow.python.util.deprecation import deprecated_args
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
# Base class for the SavedModelBuilder that is only used by Tensorflow
|
||||
# internally. Please use tf.compat.v1.saved_model.SavedModelBuilder instead.
|
||||
class _SavedModelBuilder(object):
|
||||
"""Builds the `SavedModel` protocol buffer and saves variables and assets.
|
||||
|
||||
@ -144,52 +147,6 @@ class _SavedModelBuilder(object):
|
||||
# Copy assets from source path to destination path.
|
||||
self._copy_assets_to_destination_dir(asset_filename_map)
|
||||
|
||||
def _maybe_add_main_op(self, main_op):
|
||||
"""Adds main op to the SavedModel.
|
||||
|
||||
Args:
|
||||
main_op: Main op to run as part of graph initialization. If None, no
|
||||
main op will be added to the graph.
|
||||
|
||||
Raises:
|
||||
TypeError: if main op is provided but is not of type `Operation`.
|
||||
ValueError: if the Graph already contains an init op.
|
||||
"""
|
||||
if main_op is None:
|
||||
return
|
||||
|
||||
if not isinstance(main_op, ops.Operation):
|
||||
raise TypeError("main_op needs to be an Operation: %r" % main_op)
|
||||
|
||||
# Validate that no other init ops have been added to this graph already.
|
||||
# We check main_op and legacy_init_op for thoroughness and explicitness.
|
||||
for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
|
||||
if ops.get_collection(init_op_key):
|
||||
raise ValueError(
|
||||
"Graph already contains one or more main ops under the "
|
||||
"collection {}.".format(init_op_key))
|
||||
|
||||
ops.add_to_collection(constants.MAIN_OP_KEY, main_op)
|
||||
|
||||
def _add_train_op(self, train_op):
|
||||
"""Add train op to the SavedModel.
|
||||
|
||||
Note that this functionality is in development, and liable to be
|
||||
moved elsewhere.
|
||||
|
||||
Args:
|
||||
train_op: Op or group of ops that are used for training. These are
|
||||
stored as a collection with key TRAIN_OP_KEY, but not executed.
|
||||
|
||||
Raises:
|
||||
TypeError if Train op is not of type `Operation`.
|
||||
"""
|
||||
if train_op is not None:
|
||||
if (not isinstance(train_op, ops.Tensor) and
|
||||
not isinstance(train_op, ops.Operation)):
|
||||
raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
|
||||
ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)
|
||||
|
||||
def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
|
||||
"""Tags the meta graph def and adds it to the SavedModel.
|
||||
|
||||
@ -245,12 +202,16 @@ class _SavedModelBuilder(object):
|
||||
|
||||
Validation of entries in the signature def map includes ensuring that the
|
||||
`name` and `dtype` fields of the TensorInfo protos of the `inputs` and
|
||||
`outputs` of each `SignatureDef` are populated.
|
||||
`outputs` of each `SignatureDef` are populated. Also ensures that reserved
|
||||
SigantureDef keys for the initialization and train ops are not used.
|
||||
|
||||
Args:
|
||||
signature_def_map: The map of signature defs to be validated.
|
||||
|
||||
Raises:
|
||||
AssertionError: If a TensorInfo is not valid.
|
||||
KeyError: If a reserved signature key is used in the map.
|
||||
"""
|
||||
if signature_def_map is not None:
|
||||
for signature_def_key in signature_def_map:
|
||||
signature_def = signature_def_map[signature_def_key]
|
||||
inputs = signature_def.inputs
|
||||
@ -259,12 +220,14 @@ class _SavedModelBuilder(object):
|
||||
self._validate_tensor_info(inputs[inputs_key])
|
||||
for outputs_key in outputs:
|
||||
self._validate_tensor_info(outputs[outputs_key])
|
||||
|
||||
def _add_collections(self, main_op, train_op):
|
||||
"""Add asset and op collections to be saved."""
|
||||
self._maybe_add_main_op(main_op)
|
||||
|
||||
self._add_train_op(train_op)
|
||||
if constants.INIT_OP_SIGNATURE_KEY in signature_def_map:
|
||||
raise KeyError(
|
||||
"SignatureDef map key \"{}\" is reserved for initialization. Please "
|
||||
"use a different key.".format(constants.INIT_OP_SIGNATURE_KEY))
|
||||
if constants.TRAIN_OP_SIGNATURE_KEY in signature_def_map:
|
||||
raise KeyError(
|
||||
"SignatureDef map key \"{}\" is reserved for the train op. Please "
|
||||
"use a different key.".format(constants.TRAIN_OP_SIGNATURE_KEY))
|
||||
|
||||
def _maybe_create_saver(self, saver=None):
|
||||
"""Creates a sharded saver if one does not already exist."""
|
||||
@ -278,19 +241,14 @@ class _SavedModelBuilder(object):
|
||||
allow_empty=True)
|
||||
return saver
|
||||
|
||||
@deprecated_args(None,
|
||||
"Pass your op to the equivalent parameter main_op instead.",
|
||||
"legacy_init_op")
|
||||
def add_meta_graph(self,
|
||||
tags,
|
||||
signature_def_map=None,
|
||||
assets_list=None,
|
||||
legacy_init_op=None,
|
||||
clear_devices=False,
|
||||
main_op=None,
|
||||
strip_default_attrs=False,
|
||||
init_op=None,
|
||||
train_op=None,
|
||||
saver=None):
|
||||
# pylint: disable=line-too-long
|
||||
"""Adds the current meta graph to the SavedModel.
|
||||
|
||||
Creates a Saver in the current scope and uses the Saver to export the meta
|
||||
@ -304,16 +262,14 @@ class _SavedModelBuilder(object):
|
||||
assets_list: Assets to be saved with SavedModel. Note
|
||||
that this list should be a subset of the assets saved as part of
|
||||
the first meta graph in the SavedModel.
|
||||
legacy_init_op: Legacy support for op or group of ops to execute after the
|
||||
restore op upon a load. Deprecated; please use main_op instead.
|
||||
clear_devices: Set to true if the device info on the default graph should
|
||||
be cleared.
|
||||
main_op: Op or group of ops to execute when the graph is loaded. Note
|
||||
that when the main_op is specified it is run after the restore op at
|
||||
init_op: Op or group of ops to execute when the graph is loaded. Note
|
||||
that when the init_op is specified it is run after the restore op at
|
||||
load-time.
|
||||
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
|
||||
removed from the NodeDefs. For a detailed guide, see
|
||||
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
|
||||
train_op: Op or group of opts that trains the model when run. This will
|
||||
not be run automatically when the graph is loaded, instead saved in
|
||||
a SignatureDef accessible through the exported MetaGraph.
|
||||
saver: An instance of tf.train.Saver that will be used to export the
|
||||
metagraph. If None, a sharded Saver that restores all variables will
|
||||
be used.
|
||||
@ -322,7 +278,6 @@ class _SavedModelBuilder(object):
|
||||
AssertionError: If the variables for the SavedModel have not been saved
|
||||
yet, or if the graph already contains one or more legacy init ops.
|
||||
"""
|
||||
# pylint: enable=line-too-long
|
||||
if not self._has_saved_variables:
|
||||
raise AssertionError(
|
||||
"Graph state including variables and assets has not been saved yet. "
|
||||
@ -330,14 +285,15 @@ class _SavedModelBuilder(object):
|
||||
|
||||
# Validate the signature def map to ensure all included TensorInfos are
|
||||
# properly populated.
|
||||
signature_def_map = signature_def_map or {}
|
||||
self._validate_signature_def_map(signature_def_map)
|
||||
|
||||
# legacy_init_op is deprecated, and going away in TF 2.0.
|
||||
# Re-mapping to main_op, as treatment is identical regardless.
|
||||
main_op = main_op or legacy_init_op
|
||||
|
||||
# Add ops to collection.
|
||||
self._add_collections(main_op=main_op, train_op=None)
|
||||
# Create a SignatureDef pointing to the graph initialization op, which will
|
||||
# be added to the MetaGraphDef.
|
||||
_add_op_to_signature_def_map(signature_def_map, init_op,
|
||||
constants.INIT_OP_SIGNATURE_KEY)
|
||||
_add_op_to_signature_def_map(signature_def_map, train_op,
|
||||
constants.TRAIN_OP_SIGNATURE_KEY)
|
||||
|
||||
saver = self._maybe_create_saver(saver)
|
||||
|
||||
@ -349,7 +305,7 @@ class _SavedModelBuilder(object):
|
||||
# resolved, we just leave the option set to False for now.
|
||||
# TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
|
||||
meta_graph_def = saver.export_meta_graph(
|
||||
clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
|
||||
clear_devices=clear_devices, strip_default_attrs=True)
|
||||
|
||||
# Save asset files and write them to disk, if any.
|
||||
self._save_and_write_assets(meta_graph_def, assets_list)
|
||||
@ -357,17 +313,14 @@ class _SavedModelBuilder(object):
|
||||
# Tag the meta graph def and add it to the SavedModel.
|
||||
self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
|
||||
|
||||
@deprecated_args(None,
|
||||
"Pass your op to the equivalent parameter main_op instead.",
|
||||
"legacy_init_op")
|
||||
def add_meta_graph_and_variables(self,
|
||||
sess,
|
||||
tags,
|
||||
signature_def_map=None,
|
||||
assets_list=None,
|
||||
legacy_init_op=None,
|
||||
clear_devices=False,
|
||||
main_op=None,
|
||||
init_op=None,
|
||||
train_op=None,
|
||||
strip_default_attrs=False,
|
||||
saver=None):
|
||||
# pylint: disable=line-too-long
|
||||
@ -386,13 +339,14 @@ class _SavedModelBuilder(object):
|
||||
signature_def_map: The map of signature def map to add to the meta graph
|
||||
def.
|
||||
assets_list: Assets to be saved with SavedModel.
|
||||
legacy_init_op: Legacy support for op or group of ops to execute after the
|
||||
restore op upon a load. Deprecated; please use main_op instead.
|
||||
clear_devices: Set to true if the device info on the default graph should
|
||||
be cleared.
|
||||
main_op: Op or group of ops to execute when the graph is loaded. Note
|
||||
that when the main_op is specified it is run after the restore op at
|
||||
init_op: Op or group of ops to execute when the graph is loaded. Note
|
||||
that when the init_op is specified it is run after the restore op at
|
||||
load-time.
|
||||
train_op: Op or group of ops that trains the model when run. This will
|
||||
not be run automatically when the graph is loaded, instead saved in
|
||||
a SignatureDef accessible through the exported MetaGraph.
|
||||
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
|
||||
removed from the NodeDefs. For a detailed guide, see
|
||||
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
|
||||
@ -409,14 +363,15 @@ class _SavedModelBuilder(object):
|
||||
|
||||
# Validate the signature def map to ensure all included TensorInfos are
|
||||
# properly populated.
|
||||
signature_def_map = signature_def_map or {}
|
||||
self._validate_signature_def_map(signature_def_map)
|
||||
|
||||
# legacy_init_op is deprecated, and going away in TF 2.0.
|
||||
# Re-mapping to main_op, as treatment is identical regardless.
|
||||
main_op = main_op or legacy_init_op
|
||||
|
||||
# Add ops to collection.
|
||||
self._add_collections(main_op=main_op, train_op=None)
|
||||
# Create a SignatureDef pointing to the graph initialization op, which will
|
||||
# be added to the MetaGraphDef.
|
||||
_add_op_to_signature_def_map(signature_def_map, init_op,
|
||||
constants.INIT_OP_SIGNATURE_KEY)
|
||||
_add_op_to_signature_def_map(signature_def_map, train_op,
|
||||
constants.TRAIN_OP_SIGNATURE_KEY)
|
||||
|
||||
saved_model_utils.get_or_create_variables_dir(self._export_dir)
|
||||
variables_path = saved_model_utils.get_variables_path(self._export_dir)
|
||||
@ -517,6 +472,52 @@ class SavedModelBuilder(_SavedModelBuilder):
|
||||
# Copy assets from source path to destination path.
|
||||
self._copy_assets_to_destination_dir(asset_filename_map)
|
||||
|
||||
def _maybe_add_main_op(self, main_op):
|
||||
"""Adds main op to the SavedModel.
|
||||
|
||||
Args:
|
||||
main_op: Main op to run as part of graph initialization. If None, no main
|
||||
op will be added to the graph.
|
||||
|
||||
Raises:
|
||||
TypeError: if main op is provided but is not of type `Operation`.
|
||||
ValueError: if the Graph already contains an init op.
|
||||
"""
|
||||
if main_op is None:
|
||||
return
|
||||
|
||||
if not isinstance(main_op, ops.Operation):
|
||||
raise TypeError("main_op needs to be an Operation: %r" % main_op)
|
||||
|
||||
# Validate that no other init ops have been added to this graph already.
|
||||
# We check main_op and legacy_init_op for thoroughness and explicitness.
|
||||
for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
|
||||
if ops.get_collection(init_op_key):
|
||||
raise ValueError(
|
||||
"Graph already contains one or more main ops under the "
|
||||
"collection {}.".format(init_op_key))
|
||||
|
||||
ops.add_to_collection(constants.MAIN_OP_KEY, main_op)
|
||||
|
||||
def _add_train_op(self, train_op):
|
||||
"""Add train op to the SavedModel.
|
||||
|
||||
Note that this functionality is in development, and liable to be
|
||||
moved elsewhere.
|
||||
|
||||
Args:
|
||||
train_op: Op or group of ops that are used for training. These are stored
|
||||
as a collection with key TRAIN_OP_KEY, but not executed.
|
||||
|
||||
Raises:
|
||||
TypeError if Train op is not of type `Operation`.
|
||||
"""
|
||||
if train_op is not None:
|
||||
if (not isinstance(train_op, ops.Tensor) and
|
||||
not isinstance(train_op, ops.Operation)):
|
||||
raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
|
||||
ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)
|
||||
|
||||
@deprecated_args(None,
|
||||
"Pass your op to the equivalent parameter main_op instead.",
|
||||
"legacy_init_op")
|
||||
@ -536,6 +537,7 @@ class SavedModelBuilder(_SavedModelBuilder):
|
||||
|
||||
# Validate the signature def map to ensure all included TensorInfos are
|
||||
# properly populated.
|
||||
signature_def_map = signature_def_map or {}
|
||||
self._validate_signature_def_map(signature_def_map)
|
||||
|
||||
# legacy_init_op is deprecated, and going away in TF 2.0.
|
||||
@ -580,6 +582,7 @@ class SavedModelBuilder(_SavedModelBuilder):
|
||||
|
||||
# Validate the signature def map to ensure all included TensorInfos are
|
||||
# properly populated.
|
||||
signature_def_map = signature_def_map or {}
|
||||
self._validate_signature_def_map(signature_def_map)
|
||||
|
||||
# legacy_init_op is deprecated, and going away in TF 2.0.
|
||||
@ -774,3 +777,8 @@ def _add_asset_to_collection(asset_filename, asset_tensor):
|
||||
asset_any_proto = Any()
|
||||
asset_any_proto.Pack(asset_proto)
|
||||
ops.add_to_collection(constants.ASSETS_KEY, asset_any_proto)
|
||||
|
||||
|
||||
def _add_op_to_signature_def_map(signature_def_map, op, key):
|
||||
if op is not None:
|
||||
signature_def_map[key] = signature_def_utils.op_signature_def(op, key)
|
||||
|
@ -48,7 +48,6 @@ tf_export(
|
||||
# CollectionDef key for the SavedModel main op.
|
||||
MAIN_OP_KEY = "saved_model_main_op"
|
||||
tf_export(
|
||||
"saved_model.MAIN_OP_KEY",
|
||||
v1=["saved_model.MAIN_OP_KEY",
|
||||
"saved_model.constants.MAIN_OP_KEY"]).export_constant(
|
||||
__name__, "MAIN_OP_KEY")
|
||||
@ -105,3 +104,8 @@ tf_export(
|
||||
"saved_model.VARIABLES_FILENAME",
|
||||
"saved_model.constants.VARIABLES_FILENAME"
|
||||
]).export_constant(__name__, "VARIABLES_FILENAME")
|
||||
|
||||
# The initialization and train ops for a MetaGraph are stored in the
|
||||
# signature def map. The ops are added to the map with the following keys.
|
||||
INIT_OP_SIGNATURE_KEY = "__saved_model_init_op"
|
||||
TRAIN_OP_SIGNATURE_KEY = "__saved_model_train_op"
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.saved_model import constants
|
||||
from tensorflow.python.saved_model import signature_def_utils
|
||||
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||
from tensorflow.python.training import saver as tf_saver
|
||||
from tensorflow.python.util import compat
|
||||
@ -141,15 +142,46 @@ def _get_main_op_tensor(
|
||||
RuntimeError: If the collection def corresponding to the main op key has
|
||||
other than exactly one tensor.
|
||||
"""
|
||||
# TODO(kathywu): Rename this method to _get_op_from_collection when
|
||||
# dependency from SavedModelEstimator is removed.
|
||||
collection_def = meta_graph_def_to_load.collection_def
|
||||
main_op_tensor = None
|
||||
init_op = None
|
||||
if init_op_key in collection_def:
|
||||
main_ops = collection_def[init_op_key].node_list.value
|
||||
if len(main_ops) != 1:
|
||||
raise RuntimeError("Expected exactly one SavedModel main op. "
|
||||
"Found: {}".format(main_ops))
|
||||
main_op_tensor = ops.get_collection(init_op_key)[0]
|
||||
return main_op_tensor
|
||||
init_op_list = collection_def[init_op_key].node_list.value
|
||||
if len(init_op_list) != 1:
|
||||
raise RuntimeError("Expected exactly one SavedModel init op. "
|
||||
"Found: {}".format(init_op_list))
|
||||
init_op = ops.get_collection(init_op_key)[0]
|
||||
return init_op
|
||||
|
||||
|
||||
def _get_op_from_collection(meta_graph_def, op_key):
|
||||
return _get_main_op_tensor(meta_graph_def, op_key)
|
||||
|
||||
|
||||
def _get_op_from_signature_def(meta_graph_def, op_signature_key, import_scope):
|
||||
"""Retrieve op stored in the imported meta graph's signature def."""
|
||||
if op_signature_key in meta_graph_def.signature_def:
|
||||
return signature_def_utils.load_op_from_signature_def(
|
||||
meta_graph_def.signature_def[op_signature_key], op_signature_key,
|
||||
import_scope)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_init_op(meta_graph_def, import_scope=None):
|
||||
return (_get_op_from_signature_def(
|
||||
meta_graph_def, constants.INIT_OP_SIGNATURE_KEY, import_scope) or
|
||||
_get_op_from_collection(meta_graph_def, constants.MAIN_OP_KEY) or
|
||||
_get_op_from_collection(meta_graph_def, constants.LEGACY_INIT_OP_KEY))
|
||||
|
||||
|
||||
def get_train_op(meta_graph_def, import_scope=None):
|
||||
train_op = _get_op_from_signature_def(
|
||||
meta_graph_def, constants.TRAIN_OP_SIGNATURE_KEY, import_scope)
|
||||
if train_op is None:
|
||||
train_op = _get_op_from_collection(meta_graph_def, constants.TRAIN_OP_KEY)
|
||||
return train_op
|
||||
|
||||
|
||||
@tf_export(v1=[
|
||||
@ -359,11 +391,9 @@ class SavedModelLoader(object):
|
||||
asset_tensors_dictionary = _get_asset_tensors(
|
||||
self._export_dir, meta_graph_def, import_scope=import_scope)
|
||||
|
||||
main_op_tensor = (
|
||||
_get_main_op_tensor(meta_graph_def, constants.MAIN_OP_KEY) or
|
||||
_get_main_op_tensor(meta_graph_def, constants.LEGACY_INIT_OP_KEY))
|
||||
if main_op_tensor is not None:
|
||||
sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
|
||||
init_op = get_init_op(meta_graph_def, import_scope)
|
||||
if init_op is not None:
|
||||
sess.run(fetches=[init_op], feed_dict=asset_tensors_dictionary)
|
||||
|
||||
def load(self, sess, tags, import_scope=None, **saver_kwargs):
|
||||
"""Load the MetaGraphDef graph and restore variable values into the session.
|
||||
|
@ -19,11 +19,13 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -42,55 +44,72 @@ SIMPLE_ADD_SAVED_MODEL = _get_export_dir("simple_add_saved_model")
|
||||
SAVED_MODEL_WITH_MAIN_OP = _get_export_dir("saved_model_with_main_op")
|
||||
|
||||
|
||||
class SavedModelLoaderTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Write test SavedModels to a temp directory."""
|
||||
with session.Session(graph=ops.Graph()) as sess:
|
||||
def build_graph_helper():
|
||||
g = ops.Graph()
|
||||
with g.as_default():
|
||||
x = variables.VariableV1(5, name="x")
|
||||
y = variables.VariableV1(11, name="y")
|
||||
z = x + y
|
||||
|
||||
foo_sig_def = signature_def_utils.build_signature_def({
|
||||
"foo_input": utils.build_tensor_info(x)
|
||||
}, {"foo_output": utils.build_tensor_info(z)})
|
||||
bar_sig_def = signature_def_utils.build_signature_def({
|
||||
"bar_x": utils.build_tensor_info(x),
|
||||
"bar_y": utils.build_tensor_info(y)
|
||||
}, {"bar_z": utils.build_tensor_info(z)})
|
||||
return g, {"foo": foo_sig_def, "bar": bar_sig_def}, y
|
||||
|
||||
|
||||
@parameterized.parameters((saved_model_builder.SavedModelBuilder,),
|
||||
(saved_model_builder._SavedModelBuilder,))
|
||||
class SavedModelLoaderTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def export_simple_graph(self, builder_cls):
|
||||
g, sig_def_map, _ = build_graph_helper()
|
||||
with session.Session(graph=g) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
foo_sig_def = signature_def_utils.build_signature_def(
|
||||
{"foo_input": utils.build_tensor_info(x)},
|
||||
{"foo_output": utils.build_tensor_info(z)})
|
||||
bar_sig_def = signature_def_utils.build_signature_def(
|
||||
{"bar_x": utils.build_tensor_info(x),
|
||||
"bar_y": utils.build_tensor_info(y)},
|
||||
{"bar_z": utils.build_tensor_info(z)})
|
||||
|
||||
builder = saved_model_builder.SavedModelBuilder(SIMPLE_ADD_SAVED_MODEL)
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def})
|
||||
builder = builder_cls(SIMPLE_ADD_SAVED_MODEL)
|
||||
builder.add_meta_graph_and_variables(sess, ["foo_graph"], sig_def_map)
|
||||
builder.save()
|
||||
|
||||
# Write SavedModel with a main_op
|
||||
def export_graph_with_main_op(self, builder_cls):
|
||||
g, sig_def_map, y = build_graph_helper()
|
||||
with session.Session(graph=g) as sess:
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
assign_op = control_flow_ops.group(state_ops.assign(y, 7))
|
||||
|
||||
builder = saved_model_builder.SavedModelBuilder(SAVED_MODEL_WITH_MAIN_OP)
|
||||
builder = builder_cls(SAVED_MODEL_WITH_MAIN_OP)
|
||||
|
||||
if builder_cls == saved_model_builder._SavedModelBuilder:
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def},
|
||||
main_op=assign_op)
|
||||
sess, ["foo_graph"], sig_def_map, init_op=assign_op)
|
||||
else:
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo_graph"], sig_def_map, main_op=assign_op)
|
||||
builder.save()
|
||||
|
||||
def tearDown(self):
|
||||
file_io.delete_recursively(test.get_temp_dir())
|
||||
super(SavedModelLoaderTest, self).tearDown()
|
||||
shutil.rmtree(test.get_temp_dir(), ignore_errors=True)
|
||||
|
||||
def test_load_function(self):
|
||||
def test_load_function(self, builder_cls):
|
||||
self.export_simple_graph(builder_cls)
|
||||
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
loader.load(sess, ["foo_graph"])
|
||||
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
|
||||
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
|
||||
|
||||
self.export_graph_with_main_op(builder_cls)
|
||||
loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
loader2.load(sess, ["foo_graph"])
|
||||
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
|
||||
self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
|
||||
|
||||
def test_load_graph(self):
|
||||
def test_load_graph(self, builder_cls):
|
||||
self.export_simple_graph(builder_cls)
|
||||
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
||||
graph = ops.Graph()
|
||||
loader.load_graph(graph, ["foo_graph"])
|
||||
@ -101,14 +120,15 @@ class SavedModelLoaderTest(test.TestCase):
|
||||
with self.assertRaises(KeyError):
|
||||
graph.get_tensor_by_name("z:0")
|
||||
|
||||
with self.session(graph=graph) as sess:
|
||||
with self.session(graph=graph):
|
||||
# Check that x and y are not initialized
|
||||
with self.assertRaises(errors.FailedPreconditionError):
|
||||
self.evaluate(x)
|
||||
with self.assertRaises(errors.FailedPreconditionError):
|
||||
self.evaluate(y)
|
||||
|
||||
def test_load_with_import_scope(self):
|
||||
def test_load_with_import_scope(self, builder_cls):
|
||||
self.export_graph_with_main_op(builder_cls)
|
||||
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
saver, _ = loader.load_graph(
|
||||
@ -119,6 +139,12 @@ class SavedModelLoaderTest(test.TestCase):
|
||||
loader.restore_variables(sess, tf_saver.Saver())
|
||||
|
||||
loader.restore_variables(sess, saver)
|
||||
|
||||
if builder_cls == saved_model_builder._SavedModelBuilder:
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
loader.run_init_ops(sess, ["foo_graph"])
|
||||
loader.run_init_ops(sess, ["foo_graph"], import_scope="baz")
|
||||
else:
|
||||
loader.run_init_ops(sess, ["foo_graph"])
|
||||
|
||||
self.assertEqual(5, sess.graph.get_tensor_by_name("baz/x:0").eval())
|
||||
@ -131,7 +157,8 @@ class SavedModelLoaderTest(test.TestCase):
|
||||
self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval())
|
||||
self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval())
|
||||
|
||||
def test_restore_variables(self):
|
||||
def test_restore_variables(self, builder_cls):
|
||||
self.export_graph_with_main_op(builder_cls)
|
||||
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
x = variables.VariableV1(0, name="x")
|
||||
@ -147,7 +174,8 @@ class SavedModelLoaderTest(test.TestCase):
|
||||
loader.restore_variables(sess, tf_saver.Saver())
|
||||
self.assertEqual(55, self.evaluate(z))
|
||||
|
||||
def test_run_init_op(self):
|
||||
def test_run_init_op(self, builder_cls):
|
||||
self.export_graph_with_main_op(builder_cls)
|
||||
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
|
||||
graph = ops.Graph()
|
||||
saver, _ = loader.load_graph(graph, ["foo_graph"])
|
||||
@ -160,14 +188,16 @@ class SavedModelLoaderTest(test.TestCase):
|
||||
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
|
||||
self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
|
||||
|
||||
def test_parse_saved_model(self):
|
||||
def test_parse_saved_model(self, builder_cls):
|
||||
self.export_simple_graph(builder_cls)
|
||||
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
||||
meta_graph = loader.get_meta_graph_def_from_tags(["foo_graph"])
|
||||
self.assertIsNotNone(meta_graph)
|
||||
self.assertIn("foo", meta_graph.signature_def)
|
||||
self.assertIn("bar", meta_graph.signature_def)
|
||||
|
||||
def test_load_invalid_meta_graph(self):
|
||||
def test_load_invalid_meta_graph(self, builder_cls):
|
||||
self.export_simple_graph(builder_cls)
|
||||
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
||||
with self.assertRaises(RuntimeError):
|
||||
loader.get_meta_graph_def_from_tags([])
|
||||
@ -176,13 +206,16 @@ class SavedModelLoaderTest(test.TestCase):
|
||||
with self.assertRaises(RuntimeError):
|
||||
loader.get_meta_graph_def_from_tags(["not_a_graph"])
|
||||
|
||||
def test_load_saved_model_with_no_variables(self):
|
||||
def test_load_saved_model_with_no_variables(self, builder_cls):
|
||||
"""Test that SavedModel runs saver when there appear to be no variables.
|
||||
|
||||
When no variables are detected, this may mean that the variables were saved
|
||||
to different collections, or the collections weren't saved to the
|
||||
SavedModel. If the SavedModel MetaGraphDef contains a saver, it should still
|
||||
run in either of these cases.
|
||||
|
||||
Args:
|
||||
builder_cls: SavedModelBuilder or _SavedModelBuilder class
|
||||
"""
|
||||
path = _get_export_dir("no_variable_saved_model")
|
||||
with session.Session(graph=ops.Graph()) as sess:
|
||||
@ -192,7 +225,7 @@ class SavedModelLoaderTest(test.TestCase):
|
||||
11, name="y", collections=["not_global_variable"])
|
||||
self.assertFalse(variables._all_saveable_objects())
|
||||
z = x + y
|
||||
sess.run(variables.variables_initializer([x, y]))
|
||||
self.evaluate(variables.variables_initializer([x, y]))
|
||||
|
||||
foo_sig_def = signature_def_utils.build_signature_def(
|
||||
{"foo_input": utils.build_tensor_info(x)},
|
||||
@ -215,8 +248,9 @@ class SavedModelLoaderTest(test.TestCase):
|
||||
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
|
||||
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
|
||||
|
||||
def test_load_saved_model_graph_with_return_elements(self):
|
||||
def test_load_saved_model_graph_with_return_elements(self, builder_cls):
|
||||
"""Ensure that the correct elements are returned."""
|
||||
self.export_simple_graph(builder_cls)
|
||||
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
||||
graph = ops.Graph()
|
||||
_, ret = loader.load_graph(graph, ["foo_graph"],
|
||||
@ -228,5 +262,6 @@ class SavedModelLoaderTest(test.TestCase):
|
||||
with self.assertRaisesRegexp(ValueError, "not found in graph"):
|
||||
loader.load_graph(graph, ["foo_graph"], return_elements=["z:0"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -146,6 +146,18 @@ class SavedModelTest(SavedModelTestBase):
|
||||
sess, ["foo"],
|
||||
signature_def_map={"foo_key": foo_signature})
|
||||
|
||||
def _validate_sig_def_keys(self, builder, valid_tensor_info, invalid_key):
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
self._init_and_validate_variable(sess, "v", 42)
|
||||
|
||||
foo_signature = signature_def_utils.build_signature_def(
|
||||
dict(), {"foo_key": valid_tensor_info}, "foo")
|
||||
self.assertRaises(
|
||||
KeyError,
|
||||
builder.add_meta_graph_and_variables,
|
||||
sess, ["foo"],
|
||||
signature_def_map={invalid_key: foo_signature})
|
||||
|
||||
def testMaybeSavedModelDir(self):
|
||||
base_path = test.test_src_dir_path("/python/saved_model")
|
||||
self.assertFalse(loader.maybe_saved_model_directory(base_path))
|
||||
@ -583,6 +595,15 @@ class SavedModelTest(SavedModelTestBase):
|
||||
self._validate_inputs_tensor_info_fail(builder, tensor_empty)
|
||||
self._validate_outputs_tensor_info_fail(builder, tensor_empty)
|
||||
|
||||
valid_tensor_info = meta_graph_pb2.TensorInfo()
|
||||
valid_tensor_info.name = "foo"
|
||||
valid_tensor_info.dtype = types_pb2.DT_FLOAT
|
||||
|
||||
self._validate_sig_def_keys(builder, valid_tensor_info,
|
||||
constants.INIT_OP_SIGNATURE_KEY)
|
||||
self._validate_sig_def_keys(builder, valid_tensor_info,
|
||||
constants.TRAIN_OP_SIGNATURE_KEY)
|
||||
|
||||
def testSignatureDefValidationSucceedsWithName(self):
|
||||
tensor_with_name = meta_graph_pb2.TensorInfo()
|
||||
tensor_with_name.name = "foo"
|
||||
@ -782,7 +803,7 @@ class SavedModelTest(SavedModelTestBase):
|
||||
self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt",
|
||||
"foo bar baz 0", "asset_file_tensor_0:0")
|
||||
|
||||
def testCustomMainOp(self):
|
||||
def testCustomInitOp(self):
|
||||
export_dir = self._get_export_dir("test_main_op")
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
@ -800,11 +821,11 @@ class SavedModelTest(SavedModelTestBase):
|
||||
# Set up an assignment op to be run as part of the main_op.
|
||||
with ops.control_dependencies([main_op.main_op()]):
|
||||
add_v1_v2 = math_ops.add(v1._ref(), v2._ref())
|
||||
custom_main_op = control_flow_ops.group(state_ops.assign(v3, add_v1_v2))
|
||||
custom_init_op = control_flow_ops.group(state_ops.assign(v3, add_v1_v2))
|
||||
|
||||
self.evaluate(custom_main_op)
|
||||
self.evaluate(custom_init_op)
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], main_op=custom_main_op)
|
||||
sess, ["foo"], init_op=custom_init_op)
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
@ -817,80 +838,6 @@ class SavedModelTest(SavedModelTestBase):
|
||||
# the main_op, following a restore.
|
||||
self.assertEqual(3, ops.get_collection("v")[2].eval())
|
||||
|
||||
def testLegacyInitOp(self):
|
||||
export_dir = self._get_export_dir("test_legacy_init_op")
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
# Add `v1` and `v2` variables to the graph.
|
||||
v1 = variables.VariableV1(1, name="v1")
|
||||
ops.add_to_collection("v", v1)
|
||||
v2 = variables.VariableV1(2, name="v2")
|
||||
ops.add_to_collection("v", v2)
|
||||
|
||||
# Initialize another variable `v3` to 42.
|
||||
v3 = variables.VariableV1(42, name="v3", trainable=False, collections=[])
|
||||
ops.add_to_collection("v", v3)
|
||||
|
||||
# Set up an assignment op to be run as part of the legacy_init_op.
|
||||
assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2))
|
||||
legacy_init_op = control_flow_ops.group(assign_v3, name="legacy_init_op")
|
||||
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], legacy_init_op=legacy_init_op)
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
loader.load(sess, ["foo"], export_dir)
|
||||
self.assertEqual(1, ops.get_collection("v")[0].eval())
|
||||
self.assertEqual(2, ops.get_collection("v")[1].eval())
|
||||
# Evaluates to the sum of the first two variables and assigned as part of
|
||||
# the legacy_init_op, following a restore.
|
||||
self.assertEqual(3, ops.get_collection("v")[2].eval())
|
||||
|
||||
def testLegacyInitOpWithNonEmptyCollection(self):
|
||||
export_dir = self._get_export_dir(
|
||||
"test_legacy_init_op_with_non_empty_collection")
|
||||
self._testInitOpsWithNonEmptyCollection(
|
||||
export_dir, constants.LEGACY_INIT_OP_KEY)
|
||||
|
||||
def testMainOpWithNonEmptyCollection(self):
|
||||
export_dir = self._get_export_dir(
|
||||
"test_main_op_with_non_empty_collection")
|
||||
self._testInitOpsWithNonEmptyCollection(export_dir, constants.MAIN_OP_KEY)
|
||||
|
||||
def _testInitOpsWithNonEmptyCollection(self, export_dir, key):
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
g = ops.Graph()
|
||||
with self.session(graph=g) as sess:
|
||||
# Initialize variable `v1` to 1.
|
||||
v1 = variables.VariableV1(1, name="v1")
|
||||
ops.add_to_collection("v", v1)
|
||||
|
||||
# Initialize another variable `v2` to 42.
|
||||
v2 = variables.VariableV1(42, name="v2", trainable=False, collections=[])
|
||||
ops.add_to_collection("v", v2)
|
||||
|
||||
# Set up an assignment op to be run as part of the init op.
|
||||
assign_v2 = state_ops.assign(v2, v1)
|
||||
init_op = control_flow_ops.group(assign_v2, name="init_op")
|
||||
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
ops.add_to_collection(key, control_flow_ops.no_op())
|
||||
# ValueError should be raised since the LEGACY_INIT_OP_KEY collection
|
||||
# is not empty and we don't support multiple init ops.
|
||||
with self.assertRaisesRegexp(ValueError, "Graph already contains"):
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], legacy_init_op=init_op)
|
||||
# We shouldn't be able to add as MAIN_OP, either.
|
||||
with self.assertRaisesRegexp(ValueError, "Graph already contains"):
|
||||
builder.add_meta_graph_and_variables(sess, ["foo"], main_op=init_op)
|
||||
|
||||
def testTrainOp(self):
|
||||
export_dir = self._get_export_dir("test_train_op")
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
@ -906,19 +853,17 @@ class SavedModelTest(SavedModelTestBase):
|
||||
train_op = state_ops.assign_add(v1, v2)
|
||||
|
||||
self.evaluate(train_op)
|
||||
# TODO(karmel): remove explicit call when in the public method.
|
||||
builder._add_train_op(train_op)
|
||||
builder.add_meta_graph_and_variables(sess, ["foo"])
|
||||
builder.add_meta_graph_and_variables(sess, ["foo"], train_op=train_op)
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
loader.load(sess, ["foo"], export_dir)
|
||||
meta_graph_def = loader.load(sess, ["foo"], export_dir)
|
||||
self.assertEqual(3, ops.get_collection("v")[0].eval())
|
||||
self.assertEqual(2, ops.get_collection("v")[1].eval())
|
||||
self.assertIsInstance(
|
||||
ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor)
|
||||
loader_impl.get_train_op(meta_graph_def), ops.Tensor)
|
||||
|
||||
def testTrainOpGroup(self):
|
||||
export_dir = self._get_export_dir("test_train_op_group")
|
||||
@ -935,19 +880,17 @@ class SavedModelTest(SavedModelTestBase):
|
||||
train_op = control_flow_ops.group()
|
||||
|
||||
self.evaluate(train_op)
|
||||
# TODO(karmel): remove explicit call when in the public method.
|
||||
builder._add_train_op(train_op)
|
||||
builder.add_meta_graph_and_variables(sess, ["foo"])
|
||||
builder.add_meta_graph_and_variables(sess, ["foo"], train_op=train_op)
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
loader.load(sess, ["foo"], export_dir)
|
||||
meta_graph_def = loader.load(sess, ["foo"], export_dir)
|
||||
self.assertEqual(1, ops.get_collection("v")[0].eval())
|
||||
self.assertEqual(2, ops.get_collection("v")[1].eval())
|
||||
self.assertIsInstance(
|
||||
ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Operation)
|
||||
loader_impl.get_train_op(meta_graph_def), ops.Operation)
|
||||
|
||||
def testTrainOpAfterVariables(self):
|
||||
export_dir = self._get_export_dir("test_train_op_after_variables")
|
||||
@ -965,17 +908,15 @@ class SavedModelTest(SavedModelTestBase):
|
||||
|
||||
train_op = state_ops.assign_add(v1, v2)
|
||||
self.evaluate(train_op)
|
||||
# TODO(karmel): remove explicit call when in the public method.
|
||||
builder._add_train_op(train_op)
|
||||
builder.add_meta_graph(["foo"])
|
||||
builder.add_meta_graph(["foo"], train_op=train_op)
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
loader.load(sess, ["foo"], export_dir)
|
||||
meta_graph_def = loader.load(sess, ["foo"], export_dir)
|
||||
self.assertIsInstance(
|
||||
ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor)
|
||||
loader_impl.get_train_op(meta_graph_def), ops.Tensor)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
loader.load(sess, ["pre_foo"], export_dir)
|
||||
@ -1288,76 +1229,6 @@ class SavedModelTest(SavedModelTestBase):
|
||||
self.assertEqual(
|
||||
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
|
||||
|
||||
def testStripDefaultAttrs(self):
|
||||
export_dir = self._get_export_dir("test_strip_default_attrs")
|
||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||
|
||||
# Add a graph with two float32 variables and a Complex Op composing them
|
||||
# with strip_default_attrs enabled.
|
||||
with session.Session(graph=ops.Graph()) as sess:
|
||||
real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
|
||||
imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
|
||||
math_ops.complex(real_num, imag_num, name="complex")
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], strip_default_attrs=True)
|
||||
|
||||
# Add a graph with the same float32 variables and a Complex Op composing
|
||||
# them with strip_default_attrs disabled.
|
||||
with session.Session(graph=ops.Graph()) as sess:
|
||||
real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
|
||||
imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
|
||||
math_ops.complex(real_num, imag_num, name="complex")
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
builder.add_meta_graph(["bar"], strip_default_attrs=False)
|
||||
|
||||
# Save the SavedModel to disk in text format.
|
||||
builder.save(as_text=True)
|
||||
|
||||
# Loading graph "foo" via the loader must restore the defaults for the
|
||||
# "Complex" node based on the "Complex" OpDef in the Op registry.
|
||||
sess = session.Session(graph=ops.Graph())
|
||||
meta_graph_def = loader.load(sess, ["foo"], export_dir)
|
||||
complex_node = test_util.get_node_def_from_graph("complex",
|
||||
meta_graph_def.graph_def)
|
||||
self.assertIn("T", complex_node.attr)
|
||||
self.assertIn("Tout", complex_node.attr)
|
||||
|
||||
# Load graph "foo" from disk as-is to verify default attrs are stripped.
|
||||
# pylint: disable=protected-access
|
||||
saved_model_pb = loader_impl._parse_saved_model(export_dir)
|
||||
self.assertIsNotNone(saved_model_pb)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
meta_graph_foo_def = None
|
||||
meta_graph_bar_def = None
|
||||
for meta_graph_def in saved_model_pb.meta_graphs:
|
||||
if set(meta_graph_def.meta_info_def.tags) == set(["foo"]):
|
||||
meta_graph_foo_def = meta_graph_def
|
||||
elif set(meta_graph_def.meta_info_def.tags) == set(["bar"]):
|
||||
meta_graph_bar_def = meta_graph_def
|
||||
|
||||
self.assertIsNotNone(meta_graph_foo_def)
|
||||
self.assertIsNotNone(meta_graph_bar_def)
|
||||
|
||||
# "Complex" Op has 2 attributes with defaults:
|
||||
# o "T" : float32. (input type)
|
||||
# o "Tout" : complex64. (output type)
|
||||
|
||||
# "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
|
||||
# Graph "foo" was saved with strip_default_attrs set to True.
|
||||
node_def = test_util.get_node_def_from_graph("complex",
|
||||
meta_graph_foo_def.graph_def)
|
||||
self.assertNotIn("T", node_def.attr)
|
||||
self.assertNotIn("Tout", node_def.attr)
|
||||
|
||||
# "Complex" Op in graph "bar" must have attributes "T" and "Tout".
|
||||
# Graph "bar" was saved with strip_default_attrs set to False.
|
||||
node_def = test_util.get_node_def_from_graph("complex",
|
||||
meta_graph_bar_def.graph_def)
|
||||
self.assertIn("T", node_def.attr)
|
||||
self.assertIn("Tout", node_def.attr)
|
||||
|
||||
# Tests the behavior of loading SavedModels that having missing attrs or attrs
|
||||
# with incorrect types.
|
||||
def testInconsistentConsumerDefaultAttrs(self):
|
||||
@ -1484,6 +1355,149 @@ class SavedModelV1Test(SavedModelTestBase):
|
||||
compat.as_bytes("ignored.txt"))
|
||||
self.assertFalse(file_io.file_exists(ignored_asset_path))
|
||||
|
||||
def testLegacyInitOpWithNonEmptyCollection(self):
|
||||
export_dir = self._get_export_dir(
|
||||
"test_legacy_init_op_with_non_empty_collection")
|
||||
self._testInitOpsWithNonEmptyCollection(export_dir,
|
||||
constants.LEGACY_INIT_OP_KEY)
|
||||
|
||||
def testMainOpWithNonEmptyCollection(self):
|
||||
export_dir = self._get_export_dir("test_main_op_with_non_empty_collection")
|
||||
self._testInitOpsWithNonEmptyCollection(export_dir, constants.MAIN_OP_KEY)
|
||||
|
||||
def _testInitOpsWithNonEmptyCollection(self, export_dir, key):
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
|
||||
g = ops.Graph()
|
||||
with self.session(graph=g) as sess:
|
||||
# Initialize variable `v1` to 1.
|
||||
v1 = variables.VariableV1(1, name="v1")
|
||||
ops.add_to_collection("v", v1)
|
||||
|
||||
# Initialize another variable `v2` to 42.
|
||||
v2 = variables.VariableV1(42, name="v2", trainable=False, collections=[])
|
||||
ops.add_to_collection("v", v2)
|
||||
|
||||
# Set up an assignment op to be run as part of the init op.
|
||||
assign_v2 = state_ops.assign(v2, v1)
|
||||
init_op = control_flow_ops.group(assign_v2, name="init_op")
|
||||
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
ops.add_to_collection(key, control_flow_ops.no_op())
|
||||
# ValueError should be raised since the LEGACY_INIT_OP_KEY collection
|
||||
# is not empty and we don't support multiple init ops.
|
||||
with self.assertRaisesRegexp(ValueError, "Graph already contains"):
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], legacy_init_op=init_op)
|
||||
# We shouldn't be able to add as MAIN_OP, either.
|
||||
with self.assertRaisesRegexp(ValueError, "Graph already contains"):
|
||||
builder.add_meta_graph_and_variables(sess, ["foo"], main_op=init_op)
|
||||
|
||||
def testStripDefaultAttrs(self):
|
||||
export_dir = self._get_export_dir("test_strip_default_attrs")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
|
||||
# Add a graph with two float32 variables and a Complex Op composing them
|
||||
# with strip_default_attrs enabled.
|
||||
with session.Session(graph=ops.Graph()) as sess:
|
||||
real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
|
||||
imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
|
||||
math_ops.complex(real_num, imag_num, name="complex")
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], strip_default_attrs=True)
|
||||
|
||||
# Add a graph with the same float32 variables and a Complex Op composing
|
||||
# them with strip_default_attrs disabled.
|
||||
with session.Session(graph=ops.Graph()) as sess:
|
||||
real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
|
||||
imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
|
||||
math_ops.complex(real_num, imag_num, name="complex")
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
builder.add_meta_graph(["bar"], strip_default_attrs=False)
|
||||
|
||||
# Save the SavedModel to disk in text format.
|
||||
builder.save(as_text=True)
|
||||
|
||||
# Loading graph "foo" via the loader must restore the defaults for the
|
||||
# "Complex" node based on the "Complex" OpDef in the Op registry.
|
||||
sess = session.Session(graph=ops.Graph())
|
||||
meta_graph_def = loader.load(sess, ["foo"], export_dir)
|
||||
complex_node = test_util.get_node_def_from_graph("complex",
|
||||
meta_graph_def.graph_def)
|
||||
self.assertIn("T", complex_node.attr)
|
||||
self.assertIn("Tout", complex_node.attr)
|
||||
|
||||
# Load graph "foo" from disk as-is to verify default attrs are stripped.
|
||||
# pylint: disable=protected-access
|
||||
saved_model_pb = loader_impl._parse_saved_model(export_dir)
|
||||
self.assertIsNotNone(saved_model_pb)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
meta_graph_foo_def = None
|
||||
meta_graph_bar_def = None
|
||||
for meta_graph_def in saved_model_pb.meta_graphs:
|
||||
if set(meta_graph_def.meta_info_def.tags) == set(["foo"]):
|
||||
meta_graph_foo_def = meta_graph_def
|
||||
elif set(meta_graph_def.meta_info_def.tags) == set(["bar"]):
|
||||
meta_graph_bar_def = meta_graph_def
|
||||
|
||||
self.assertIsNotNone(meta_graph_foo_def)
|
||||
self.assertIsNotNone(meta_graph_bar_def)
|
||||
|
||||
# "Complex" Op has 2 attributes with defaults:
|
||||
# o "T" : float32. (input type)
|
||||
# o "Tout" : complex64. (output type)
|
||||
|
||||
# "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
|
||||
# Graph "foo" was saved with strip_default_attrs set to True.
|
||||
node_def = test_util.get_node_def_from_graph("complex",
|
||||
meta_graph_foo_def.graph_def)
|
||||
self.assertNotIn("T", node_def.attr)
|
||||
self.assertNotIn("Tout", node_def.attr)
|
||||
|
||||
# "Complex" Op in graph "bar" must have attributes "T" and "Tout".
|
||||
# Graph "bar" was saved with strip_default_attrs set to False.
|
||||
node_def = test_util.get_node_def_from_graph("complex",
|
||||
meta_graph_bar_def.graph_def)
|
||||
self.assertIn("T", node_def.attr)
|
||||
self.assertIn("Tout", node_def.attr)
|
||||
|
||||
def testLegacyInitOp(self):
|
||||
export_dir = self._get_export_dir("test_legacy_init_op")
|
||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
# Add `v1` and `v2` variables to the graph.
|
||||
v1 = variables.VariableV1(1, name="v1")
|
||||
ops.add_to_collection("v", v1)
|
||||
v2 = variables.VariableV1(2, name="v2")
|
||||
ops.add_to_collection("v", v2)
|
||||
|
||||
# Initialize another variable `v3` to 42.
|
||||
v3 = variables.VariableV1(42, name="v3", trainable=False, collections=[])
|
||||
ops.add_to_collection("v", v3)
|
||||
|
||||
# Set up an assignment op to be run as part of the init_op.
|
||||
assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2))
|
||||
legacy_init_op = control_flow_ops.group(assign_v3, name="legacy_init_op")
|
||||
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess, ["foo"], legacy_init_op=legacy_init_op)
|
||||
|
||||
# Save the SavedModel to disk.
|
||||
builder.save()
|
||||
|
||||
with self.session(graph=ops.Graph()) as sess:
|
||||
loader.load(sess, ["foo"], export_dir)
|
||||
self.assertEqual(1, ops.get_collection("v")[0].eval())
|
||||
self.assertEqual(2, ops.get_collection("v")[1].eval())
|
||||
# Evaluates to the sum of the first two variables and assigned as part of
|
||||
# the legacy_init_op, following a restore.
|
||||
self.assertEqual(3, ops.get_collection("v")[2].eval())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -24,6 +24,8 @@ from __future__ import print_function
|
||||
from tensorflow.python.saved_model.signature_def_utils_impl import build_signature_def
|
||||
from tensorflow.python.saved_model.signature_def_utils_impl import classification_signature_def
|
||||
from tensorflow.python.saved_model.signature_def_utils_impl import is_valid_signature
|
||||
from tensorflow.python.saved_model.signature_def_utils_impl import load_op_from_signature_def
|
||||
from tensorflow.python.saved_model.signature_def_utils_impl import op_signature_def
|
||||
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
|
||||
from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def
|
||||
from tensorflow.python.saved_model.signature_def_utils_impl import supervised_eval_signature_def
|
||||
|
@ -21,9 +21,10 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.saved_model import utils
|
||||
from tensorflow.python.saved_model import utils_impl as utils
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -349,3 +350,51 @@ def _is_valid_classification_signature(signature_def):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def op_signature_def(op, key):
|
||||
"""Creates a signature def with the output pointing to an op.
|
||||
|
||||
Note that op isn't strictly enforced to be an Op object, and may be a Tensor.
|
||||
It is recommended to use the build_signature_def() function for Tensors.
|
||||
|
||||
Args:
|
||||
op: An Op (or possibly Tensor).
|
||||
key: Key to graph element in the SignatureDef outputs.
|
||||
|
||||
Returns:
|
||||
A SignatureDef with a single output pointing to the op.
|
||||
"""
|
||||
# Use build_tensor_info_from_op, which creates a TensorInfo from the element's
|
||||
# name.
|
||||
return build_signature_def(outputs={key: utils.build_tensor_info_from_op(op)})
|
||||
|
||||
|
||||
def load_op_from_signature_def(signature_def, key, import_scope=None):
|
||||
"""Load an Op from a SignatureDef created by op_signature_def().
|
||||
|
||||
Args:
|
||||
signature_def: a SignatureDef proto
|
||||
key: string key to op in the SignatureDef outputs.
|
||||
import_scope: Scope used to import the op
|
||||
|
||||
Returns:
|
||||
Op (or possibly Tensor) in the graph with the same name as saved in the
|
||||
SignatureDef.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If the op could not be found in the graph.
|
||||
"""
|
||||
tensor_info = signature_def.outputs[key]
|
||||
try:
|
||||
# The init and train ops are not strictly enforced to be operations, so
|
||||
# retrieve any graph element (can be either op or tensor).
|
||||
return utils.get_element_from_tensor_info(
|
||||
tensor_info, import_scope=import_scope)
|
||||
except KeyError:
|
||||
raise errors.NotFoundError(
|
||||
None, None,
|
||||
'The {0} could not be found in the graph. Please make sure the '
|
||||
'SavedModel was created by the internal _SavedModelBuilder. If you '
|
||||
'are using the public API, please make sure the SignatureDef in the '
|
||||
'SavedModel does not contain the key "{0}".'.format(key))
|
||||
|
@ -23,6 +23,7 @@ from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.saved_model import signature_def_utils_impl
|
||||
@ -413,5 +414,22 @@ class SignatureDefUtilsTest(test.TestCase):
|
||||
{},
|
||||
signature_constants.PREDICT_METHOD_NAME)
|
||||
|
||||
def testOpSignatureDef(self):
|
||||
key = "adding_1_and_2_key"
|
||||
add_op = math_ops.add(1, 2, name="adding_1_and_2")
|
||||
signature_def = signature_def_utils_impl.op_signature_def(add_op, key)
|
||||
self.assertIn(key, signature_def.outputs)
|
||||
self.assertEqual(add_op.name, signature_def.outputs[key].name)
|
||||
|
||||
def testLoadOpFromSignatureDef(self):
|
||||
key = "adding_1_and_2_key"
|
||||
add_op = math_ops.add(1, 2, name="adding_1_and_2")
|
||||
signature_def = signature_def_utils_impl.op_signature_def(add_op, key)
|
||||
|
||||
self.assertEqual(
|
||||
add_op,
|
||||
signature_def_utils_impl.load_op_from_signature_def(signature_def, key))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -141,6 +141,27 @@ def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
|
||||
raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
|
||||
|
||||
|
||||
def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
|
||||
"""Returns the element in the graph described by a TensorInfo proto.
|
||||
|
||||
Args:
|
||||
tensor_info: A TensorInfo proto describing an Op or Tensor by name.
|
||||
graph: The tf.Graph in which tensors are looked up. If None, the current
|
||||
default graph is used.
|
||||
import_scope: If not None, names in `tensor_info` are prefixed with this
|
||||
string before lookup.
|
||||
|
||||
Returns:
|
||||
Op or tensor in `graph` described by `tensor_info`.
|
||||
|
||||
Raises:
|
||||
KeyError: If `tensor_info` does not correspond to an op or tensor in `graph`
|
||||
"""
|
||||
graph = graph or ops.get_default_graph()
|
||||
return graph.as_graph_element(
|
||||
ops.prepend_name_scope(tensor_info.name, import_scope=import_scope))
|
||||
|
||||
|
||||
# Path helpers.
|
||||
|
||||
|
||||
|
@ -32,10 +32,6 @@ tf_module {
|
||||
name: "GPU"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "MAIN_OP_KEY"
|
||||
mtype: "<type \'str\'>"
|
||||
}
|
||||
member {
|
||||
name: "PREDICT_INPUTS"
|
||||
mtype: "<type \'str\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user