Save main_op and train_op in SignatureDefs instead of collections.
PiperOrigin-RevId: 223268983
This commit is contained in:
parent
99c20bf32e
commit
693c1ad4a9
@ -133,5 +133,6 @@ filegroup(
|
|||||||
"testdata/half_plus_two_pbtxt/**",
|
"testdata/half_plus_two_pbtxt/**",
|
||||||
"testdata/half_plus_two_main_op/**",
|
"testdata/half_plus_two_main_op/**",
|
||||||
"testdata/half_plus_two/**",
|
"testdata/half_plus_two/**",
|
||||||
|
"testdata/half_plus_two_v2/**",
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
@ -33,10 +33,10 @@ constexpr char kSavedModelFilenamePb[] = "saved_model.pb";
|
|||||||
/// SavedModel text format proto filename.
|
/// SavedModel text format proto filename.
|
||||||
constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
|
constexpr char kSavedModelFilenamePbTxt[] = "saved_model.pbtxt";
|
||||||
|
|
||||||
/// SavedModel legacy init op key.
|
/// SavedModel legacy init op collection key. Used in v1 SavedModels.
|
||||||
constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
|
constexpr char kSavedModelLegacyInitOpKey[] = "legacy_init_op";
|
||||||
|
|
||||||
/// SavedModel main op key.
|
/// SavedModel main op collection key. Used in v1 SavedModels.
|
||||||
constexpr char kSavedModelMainOpKey[] = "saved_model_main_op";
|
constexpr char kSavedModelMainOpKey[] = "saved_model_main_op";
|
||||||
|
|
||||||
/// Directory in which to save the SavedModel variables.
|
/// Directory in which to save the SavedModel variables.
|
||||||
@ -45,6 +45,11 @@ constexpr char kSavedModelVariablesDirectory[] = "variables";
|
|||||||
/// SavedModel variables filename.
|
/// SavedModel variables filename.
|
||||||
constexpr char kSavedModelVariablesFilename[] = "variables";
|
constexpr char kSavedModelVariablesFilename[] = "variables";
|
||||||
|
|
||||||
|
/// SavedModel SignatureDef keys for the initialization and train ops. Used in
|
||||||
|
/// V2 SavedModels.
|
||||||
|
constexpr char kSavedModelInitOpSignatureKey[] = "__saved_model_init_op";
|
||||||
|
constexpr char kSavedModelTrainOpSignatureKey[] = "__saved_model_train_op";
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
|
#endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_
|
||||||
|
@ -122,38 +122,58 @@ Status RunOnce(const RunOptions& run_options,
|
|||||||
return run_status;
|
return run_status;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HasMainOp(const MetaGraphDef& meta_graph_def) {
|
// RunInitOp will return OK if the initialization op was run successfully.
|
||||||
const auto& collection_def_map = meta_graph_def.collection_def();
|
// An empty init_op_name indicates that there are no init ops to run.
|
||||||
if (collection_def_map.find(kSavedModelMainOpKey) !=
|
Status RunInitOp(const RunOptions& run_options, const string& export_dir,
|
||||||
collection_def_map.end()) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status RunMainOp(const RunOptions& run_options, const string& export_dir,
|
|
||||||
const MetaGraphDef& meta_graph_def,
|
const MetaGraphDef& meta_graph_def,
|
||||||
const std::vector<AssetFileDef>& asset_file_defs,
|
const std::vector<AssetFileDef>& asset_file_defs,
|
||||||
Session* session, const string& main_op_key) {
|
Session* session, const string& init_op_name) {
|
||||||
LOG(INFO) << "Running MainOp with key " << main_op_key
|
if (!init_op_name.empty()) {
|
||||||
<< " on SavedModel bundle.";
|
LOG(INFO) << "Running initialization op on SavedModel bundle.";
|
||||||
const auto& collection_def_map = meta_graph_def.collection_def();
|
|
||||||
const auto main_op_it = collection_def_map.find(main_op_key);
|
|
||||||
if (main_op_it != collection_def_map.end()) {
|
|
||||||
if (main_op_it->second.node_list().value_size() != 1) {
|
|
||||||
return errors::FailedPrecondition(
|
|
||||||
strings::StrCat("Expected exactly one main op in : ", export_dir));
|
|
||||||
}
|
|
||||||
std::vector<std::pair<string, Tensor>> inputs;
|
std::vector<std::pair<string, Tensor>> inputs;
|
||||||
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
|
AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
|
||||||
RunMetadata run_metadata;
|
RunMetadata run_metadata;
|
||||||
const StringPiece main_op_name = main_op_it->second.node_list().value(0);
|
return RunOnce(run_options, inputs, {}, {init_op_name},
|
||||||
return RunOnce(run_options, inputs, {}, {string(main_op_name)},
|
|
||||||
nullptr /* outputs */, &run_metadata, session);
|
nullptr /* outputs */, &run_metadata, session);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// A SavedModel may store the name of the initialization op to run in the
|
||||||
|
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
|
||||||
|
// exists, then the collection must contain exactly one op.
|
||||||
|
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
|
||||||
|
string* init_op_name) {
|
||||||
|
const auto& sig_def_map = meta_graph_def.signature_def();
|
||||||
|
const auto& init_op_sig_it =
|
||||||
|
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
|
||||||
|
if (init_op_sig_it != sig_def_map.end()) {
|
||||||
|
*init_op_name = init_op_sig_it->second.outputs()
|
||||||
|
.find(kSavedModelInitOpSignatureKey)
|
||||||
|
->second.name();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto& collection_def_map = meta_graph_def.collection_def();
|
||||||
|
string init_op_collection_key;
|
||||||
|
if (collection_def_map.find(kSavedModelMainOpKey) !=
|
||||||
|
collection_def_map.end()) {
|
||||||
|
init_op_collection_key = kSavedModelMainOpKey;
|
||||||
|
} else {
|
||||||
|
init_op_collection_key = kSavedModelLegacyInitOpKey;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto init_op_it = collection_def_map.find(init_op_collection_key);
|
||||||
|
if (init_op_it != collection_def_map.end()) {
|
||||||
|
if (init_op_it->second.node_list().value_size() != 1) {
|
||||||
|
return errors::FailedPrecondition(
|
||||||
|
strings::StrCat("Expected exactly one main op in : ", export_dir));
|
||||||
|
}
|
||||||
|
*init_op_name = init_op_it->second.node_list().value(0);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status RunRestore(const RunOptions& run_options, const string& export_dir,
|
Status RunRestore(const RunOptions& run_options, const string& export_dir,
|
||||||
const StringPiece restore_op_name,
|
const StringPiece restore_op_name,
|
||||||
const StringPiece variable_filename_const_op_name,
|
const StringPiece variable_filename_const_op_name,
|
||||||
@ -236,15 +256,12 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
|
|||||||
bundle->meta_graph_def.saver_def().restore_op_name(),
|
bundle->meta_graph_def.saver_def().restore_op_name(),
|
||||||
bundle->meta_graph_def.saver_def().filename_tensor_name(),
|
bundle->meta_graph_def.saver_def().filename_tensor_name(),
|
||||||
asset_file_defs, bundle->session.get()));
|
asset_file_defs, bundle->session.get()));
|
||||||
if (HasMainOp(bundle->meta_graph_def)) {
|
string init_op_name;
|
||||||
TF_RETURN_IF_ERROR(RunMainOp(run_options, export_dir,
|
TF_RETURN_IF_ERROR(
|
||||||
bundle->meta_graph_def, asset_file_defs,
|
GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
|
||||||
bundle->session.get(), kSavedModelMainOpKey));
|
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
|
||||||
} else {
|
asset_file_defs, bundle->session.get(),
|
||||||
TF_RETURN_IF_ERROR(RunMainOp(
|
init_op_name));
|
||||||
run_options, export_dir, bundle->meta_graph_def, asset_file_defs,
|
|
||||||
bundle->session.get(), kSavedModelLegacyInitOpKey));
|
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -36,6 +36,8 @@ constexpr char kTestDataMainOp[] =
|
|||||||
"cc/saved_model/testdata/half_plus_two_main_op/00000123";
|
"cc/saved_model/testdata/half_plus_two_main_op/00000123";
|
||||||
constexpr char kTestDataSharded[] =
|
constexpr char kTestDataSharded[] =
|
||||||
"cc/saved_model/testdata/half_plus_two/00000123";
|
"cc/saved_model/testdata/half_plus_two/00000123";
|
||||||
|
constexpr char kTestDataInitOpV2[] =
|
||||||
|
"cc/saved_model/testdata/half_plus_two_v2/00000123";
|
||||||
|
|
||||||
class LoaderTest : public ::testing::Test {
|
class LoaderTest : public ::testing::Test {
|
||||||
protected:
|
protected:
|
||||||
@ -227,5 +229,17 @@ TEST_F(LoaderTest, MaybeSavedModelDirectory) {
|
|||||||
EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir));
|
EXPECT_FALSE(MaybeSavedModelDirectory(invalid_export_dir));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(LoaderTest, SavedModelInitOpV2Format) {
|
||||||
|
SavedModelBundle bundle;
|
||||||
|
SessionOptions session_options;
|
||||||
|
RunOptions run_options;
|
||||||
|
|
||||||
|
const string export_dir =
|
||||||
|
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataInitOpV2);
|
||||||
|
TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir,
|
||||||
|
{kSavedModelTagServe}, &bundle));
|
||||||
|
CheckSavedModelBundle(export_dir, bundle);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
1
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt
vendored
Normal file
1
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/assets/foo.txt
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
asset-file-contents
|
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/saved_model.pb
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.data-00000-of-00001
vendored
Normal file
Binary file not shown.
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index
vendored
Normal file
BIN
tensorflow/cc/saved_model/testdata/half_plus_two_v2/00000123/variables/variables.index
vendored
Normal file
Binary file not shown.
@ -125,7 +125,7 @@ def save_keras_model(
|
|||||||
export_dir = export_helpers.get_timestamped_export_dir(saved_model_path)
|
export_dir = export_helpers.get_timestamped_export_dir(saved_model_path)
|
||||||
temp_export_dir = export_helpers.get_temp_export_dir(export_dir)
|
temp_export_dir = export_helpers.get_temp_export_dir(export_dir)
|
||||||
|
|
||||||
builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
|
builder = saved_model_builder._SavedModelBuilder(temp_export_dir)
|
||||||
|
|
||||||
# Manually save variables to export them in an object-based checkpoint. This
|
# Manually save variables to export them in an object-based checkpoint. This
|
||||||
# skips the `builder.add_meta_graph_and_variables()` step, which saves a
|
# skips the `builder.add_meta_graph_and_variables()` step, which saves a
|
||||||
@ -227,9 +227,10 @@ def _export_mode(
|
|||||||
g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)
|
g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)
|
||||||
|
|
||||||
# Extract update and train ops from train/test/predict functions.
|
# Extract update and train ops from train/test/predict functions.
|
||||||
|
train_op = None
|
||||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||||
clone._make_train_function()
|
clone._make_train_function()
|
||||||
builder._add_train_op(clone.train_function.updates_op)
|
train_op = clone.train_function.updates_op
|
||||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||||
clone._make_test_function()
|
clone._make_test_function()
|
||||||
else:
|
else:
|
||||||
@ -264,7 +265,8 @@ def _export_mode(
|
|||||||
model_fn_lib.EXPORT_TAG_MAP[mode],
|
model_fn_lib.EXPORT_TAG_MAP[mode],
|
||||||
signature_def_map=_create_signature_def_map(clone, mode),
|
signature_def_map=_create_signature_def_map(clone, mode),
|
||||||
saver=saver_lib.Saver(clone_var_list),
|
saver=saver_lib.Saver(clone_var_list),
|
||||||
main_op=variables.local_variables_initializer())
|
init_op=variables.local_variables_initializer(),
|
||||||
|
train_op=train_op)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@ -35,7 +35,6 @@ from tensorflow.python.keras.engine import training
|
|||||||
from tensorflow.python.keras.utils import tf_utils
|
from tensorflow.python.keras.utils import tf_utils
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.saved_model import constants
|
|
||||||
from tensorflow.python.saved_model import loader_impl
|
from tensorflow.python.saved_model import loader_impl
|
||||||
from tensorflow.python.saved_model import signature_constants
|
from tensorflow.python.saved_model import signature_constants
|
||||||
from tensorflow.python.training import training as training_module
|
from tensorflow.python.training import training as training_module
|
||||||
@ -254,7 +253,7 @@ def load_model(sess, path, mode):
|
|||||||
outputs = {
|
outputs = {
|
||||||
k: sess.graph.get_tensor_by_name(v.name)
|
k: sess.graph.get_tensor_by_name(v.name)
|
||||||
for k, v in meta_graph_def.signature_def[sig_def_key].outputs.items()}
|
for k, v in meta_graph_def.signature_def[sig_def_key].outputs.items()}
|
||||||
return inputs, outputs
|
return inputs, outputs, meta_graph_def
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
@ -331,7 +330,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
# Load predict graph, and test predictions
|
# Load predict graph, and test predictions
|
||||||
with session.Session(graph=ops.Graph()) as sess:
|
with session.Session(graph=ops.Graph()) as sess:
|
||||||
inputs, outputs = load_model(sess, output_path,
|
inputs, outputs, _ = load_model(sess, output_path,
|
||||||
model_fn_lib.ModeKeys.PREDICT)
|
model_fn_lib.ModeKeys.PREDICT)
|
||||||
|
|
||||||
predictions = sess.run(outputs[output_name],
|
predictions = sess.run(outputs[output_name],
|
||||||
@ -341,7 +340,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
|||||||
if optimizer:
|
if optimizer:
|
||||||
# Load eval graph, and test predictions, loss and metric values
|
# Load eval graph, and test predictions, loss and metric values
|
||||||
with session.Session(graph=ops.Graph()) as sess:
|
with session.Session(graph=ops.Graph()) as sess:
|
||||||
inputs, outputs = load_model(sess, output_path,
|
inputs, outputs, _ = load_model(sess, output_path,
|
||||||
model_fn_lib.ModeKeys.EVAL)
|
model_fn_lib.ModeKeys.EVAL)
|
||||||
|
|
||||||
# First obtain the loss and predictions, and run the metric update op by
|
# First obtain the loss and predictions, and run the metric update op by
|
||||||
@ -365,8 +364,8 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
# Load train graph, and check for the train op, and prediction values
|
# Load train graph, and check for the train op, and prediction values
|
||||||
with session.Session(graph=ops.Graph()) as sess:
|
with session.Session(graph=ops.Graph()) as sess:
|
||||||
inputs, outputs = load_model(sess, output_path,
|
inputs, outputs, meta_graph_def = load_model(
|
||||||
model_fn_lib.ModeKeys.TRAIN)
|
sess, output_path, model_fn_lib.ModeKeys.TRAIN)
|
||||||
self.assertEqual(int(train_before_export),
|
self.assertEqual(int(train_before_export),
|
||||||
sess.run(training_module.get_global_step()))
|
sess.run(training_module.get_global_step()))
|
||||||
self.assertIn('loss', outputs)
|
self.assertIn('loss', outputs)
|
||||||
@ -375,7 +374,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertIn('predictions/' + output_name, outputs)
|
self.assertIn('predictions/' + output_name, outputs)
|
||||||
|
|
||||||
# Train for a step
|
# Train for a step
|
||||||
train_op = ops.get_collection(constants.TRAIN_OP_KEY)
|
train_op = loader_impl.get_train_op(meta_graph_def)
|
||||||
train_outputs, _ = sess.run(
|
train_outputs, _ = sess.run(
|
||||||
[outputs, train_op], {inputs[input_name]: input_arr,
|
[outputs, train_op], {inputs[input_name]: input_arr,
|
||||||
inputs[target_name]: target_arr})
|
inputs[target_name]: target_arr})
|
||||||
@ -402,7 +401,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
|||||||
output_path = keras_saved_model.save_keras_model(
|
output_path = keras_saved_model.save_keras_model(
|
||||||
model, saved_model_path, custom_objects={'relu6': relu6})
|
model, saved_model_path, custom_objects={'relu6': relu6})
|
||||||
with session.Session(graph=ops.Graph()) as sess:
|
with session.Session(graph=ops.Graph()) as sess:
|
||||||
inputs, outputs = load_model(sess, output_path,
|
inputs, outputs, _ = load_model(sess, output_path,
|
||||||
model_fn_lib.ModeKeys.PREDICT)
|
model_fn_lib.ModeKeys.PREDICT)
|
||||||
input_name = model.input_names[0]
|
input_name = model.input_names[0]
|
||||||
output_name = model.output_names[0]
|
output_name = model.output_names[0]
|
||||||
|
@ -83,6 +83,7 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":constants",
|
":constants",
|
||||||
|
":signature_def_utils",
|
||||||
":utils",
|
":utils",
|
||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
"//tensorflow/python:framework_for_generated_wrappers",
|
"//tensorflow/python:framework_for_generated_wrappers",
|
||||||
@ -114,6 +115,7 @@ py_test(
|
|||||||
"//tensorflow/python:state_ops",
|
"//tensorflow/python:state_ops",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -33,6 +33,7 @@ from tensorflow.python.lib.io import file_io
|
|||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import tf_logging
|
from tensorflow.python.platform import tf_logging
|
||||||
from tensorflow.python.saved_model import constants
|
from tensorflow.python.saved_model import constants
|
||||||
|
from tensorflow.python.saved_model import signature_def_utils
|
||||||
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||||
from tensorflow.python.training import saver as tf_saver
|
from tensorflow.python.training import saver as tf_saver
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
@ -40,6 +41,8 @@ from tensorflow.python.util.deprecation import deprecated_args
|
|||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
|
# Base class for the SavedModelBuilder that is only used by Tensorflow
|
||||||
|
# internally. Please use tf.compat.v1.saved_model.SavedModelBuilder instead.
|
||||||
class _SavedModelBuilder(object):
|
class _SavedModelBuilder(object):
|
||||||
"""Builds the `SavedModel` protocol buffer and saves variables and assets.
|
"""Builds the `SavedModel` protocol buffer and saves variables and assets.
|
||||||
|
|
||||||
@ -144,52 +147,6 @@ class _SavedModelBuilder(object):
|
|||||||
# Copy assets from source path to destination path.
|
# Copy assets from source path to destination path.
|
||||||
self._copy_assets_to_destination_dir(asset_filename_map)
|
self._copy_assets_to_destination_dir(asset_filename_map)
|
||||||
|
|
||||||
def _maybe_add_main_op(self, main_op):
|
|
||||||
"""Adds main op to the SavedModel.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
main_op: Main op to run as part of graph initialization. If None, no
|
|
||||||
main op will be added to the graph.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: if main op is provided but is not of type `Operation`.
|
|
||||||
ValueError: if the Graph already contains an init op.
|
|
||||||
"""
|
|
||||||
if main_op is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
if not isinstance(main_op, ops.Operation):
|
|
||||||
raise TypeError("main_op needs to be an Operation: %r" % main_op)
|
|
||||||
|
|
||||||
# Validate that no other init ops have been added to this graph already.
|
|
||||||
# We check main_op and legacy_init_op for thoroughness and explicitness.
|
|
||||||
for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
|
|
||||||
if ops.get_collection(init_op_key):
|
|
||||||
raise ValueError(
|
|
||||||
"Graph already contains one or more main ops under the "
|
|
||||||
"collection {}.".format(init_op_key))
|
|
||||||
|
|
||||||
ops.add_to_collection(constants.MAIN_OP_KEY, main_op)
|
|
||||||
|
|
||||||
def _add_train_op(self, train_op):
|
|
||||||
"""Add train op to the SavedModel.
|
|
||||||
|
|
||||||
Note that this functionality is in development, and liable to be
|
|
||||||
moved elsewhere.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
train_op: Op or group of ops that are used for training. These are
|
|
||||||
stored as a collection with key TRAIN_OP_KEY, but not executed.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError if Train op is not of type `Operation`.
|
|
||||||
"""
|
|
||||||
if train_op is not None:
|
|
||||||
if (not isinstance(train_op, ops.Tensor) and
|
|
||||||
not isinstance(train_op, ops.Operation)):
|
|
||||||
raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
|
|
||||||
ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)
|
|
||||||
|
|
||||||
def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
|
def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
|
||||||
"""Tags the meta graph def and adds it to the SavedModel.
|
"""Tags the meta graph def and adds it to the SavedModel.
|
||||||
|
|
||||||
@ -245,12 +202,16 @@ class _SavedModelBuilder(object):
|
|||||||
|
|
||||||
Validation of entries in the signature def map includes ensuring that the
|
Validation of entries in the signature def map includes ensuring that the
|
||||||
`name` and `dtype` fields of the TensorInfo protos of the `inputs` and
|
`name` and `dtype` fields of the TensorInfo protos of the `inputs` and
|
||||||
`outputs` of each `SignatureDef` are populated.
|
`outputs` of each `SignatureDef` are populated. Also ensures that reserved
|
||||||
|
SigantureDef keys for the initialization and train ops are not used.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
signature_def_map: The map of signature defs to be validated.
|
signature_def_map: The map of signature defs to be validated.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If a TensorInfo is not valid.
|
||||||
|
KeyError: If a reserved signature key is used in the map.
|
||||||
"""
|
"""
|
||||||
if signature_def_map is not None:
|
|
||||||
for signature_def_key in signature_def_map:
|
for signature_def_key in signature_def_map:
|
||||||
signature_def = signature_def_map[signature_def_key]
|
signature_def = signature_def_map[signature_def_key]
|
||||||
inputs = signature_def.inputs
|
inputs = signature_def.inputs
|
||||||
@ -259,12 +220,14 @@ class _SavedModelBuilder(object):
|
|||||||
self._validate_tensor_info(inputs[inputs_key])
|
self._validate_tensor_info(inputs[inputs_key])
|
||||||
for outputs_key in outputs:
|
for outputs_key in outputs:
|
||||||
self._validate_tensor_info(outputs[outputs_key])
|
self._validate_tensor_info(outputs[outputs_key])
|
||||||
|
if constants.INIT_OP_SIGNATURE_KEY in signature_def_map:
|
||||||
def _add_collections(self, main_op, train_op):
|
raise KeyError(
|
||||||
"""Add asset and op collections to be saved."""
|
"SignatureDef map key \"{}\" is reserved for initialization. Please "
|
||||||
self._maybe_add_main_op(main_op)
|
"use a different key.".format(constants.INIT_OP_SIGNATURE_KEY))
|
||||||
|
if constants.TRAIN_OP_SIGNATURE_KEY in signature_def_map:
|
||||||
self._add_train_op(train_op)
|
raise KeyError(
|
||||||
|
"SignatureDef map key \"{}\" is reserved for the train op. Please "
|
||||||
|
"use a different key.".format(constants.TRAIN_OP_SIGNATURE_KEY))
|
||||||
|
|
||||||
def _maybe_create_saver(self, saver=None):
|
def _maybe_create_saver(self, saver=None):
|
||||||
"""Creates a sharded saver if one does not already exist."""
|
"""Creates a sharded saver if one does not already exist."""
|
||||||
@ -278,19 +241,14 @@ class _SavedModelBuilder(object):
|
|||||||
allow_empty=True)
|
allow_empty=True)
|
||||||
return saver
|
return saver
|
||||||
|
|
||||||
@deprecated_args(None,
|
|
||||||
"Pass your op to the equivalent parameter main_op instead.",
|
|
||||||
"legacy_init_op")
|
|
||||||
def add_meta_graph(self,
|
def add_meta_graph(self,
|
||||||
tags,
|
tags,
|
||||||
signature_def_map=None,
|
signature_def_map=None,
|
||||||
assets_list=None,
|
assets_list=None,
|
||||||
legacy_init_op=None,
|
|
||||||
clear_devices=False,
|
clear_devices=False,
|
||||||
main_op=None,
|
init_op=None,
|
||||||
strip_default_attrs=False,
|
train_op=None,
|
||||||
saver=None):
|
saver=None):
|
||||||
# pylint: disable=line-too-long
|
|
||||||
"""Adds the current meta graph to the SavedModel.
|
"""Adds the current meta graph to the SavedModel.
|
||||||
|
|
||||||
Creates a Saver in the current scope and uses the Saver to export the meta
|
Creates a Saver in the current scope and uses the Saver to export the meta
|
||||||
@ -304,16 +262,14 @@ class _SavedModelBuilder(object):
|
|||||||
assets_list: Assets to be saved with SavedModel. Note
|
assets_list: Assets to be saved with SavedModel. Note
|
||||||
that this list should be a subset of the assets saved as part of
|
that this list should be a subset of the assets saved as part of
|
||||||
the first meta graph in the SavedModel.
|
the first meta graph in the SavedModel.
|
||||||
legacy_init_op: Legacy support for op or group of ops to execute after the
|
|
||||||
restore op upon a load. Deprecated; please use main_op instead.
|
|
||||||
clear_devices: Set to true if the device info on the default graph should
|
clear_devices: Set to true if the device info on the default graph should
|
||||||
be cleared.
|
be cleared.
|
||||||
main_op: Op or group of ops to execute when the graph is loaded. Note
|
init_op: Op or group of ops to execute when the graph is loaded. Note
|
||||||
that when the main_op is specified it is run after the restore op at
|
that when the init_op is specified it is run after the restore op at
|
||||||
load-time.
|
load-time.
|
||||||
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
|
train_op: Op or group of opts that trains the model when run. This will
|
||||||
removed from the NodeDefs. For a detailed guide, see
|
not be run automatically when the graph is loaded, instead saved in
|
||||||
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
|
a SignatureDef accessible through the exported MetaGraph.
|
||||||
saver: An instance of tf.train.Saver that will be used to export the
|
saver: An instance of tf.train.Saver that will be used to export the
|
||||||
metagraph. If None, a sharded Saver that restores all variables will
|
metagraph. If None, a sharded Saver that restores all variables will
|
||||||
be used.
|
be used.
|
||||||
@ -322,7 +278,6 @@ class _SavedModelBuilder(object):
|
|||||||
AssertionError: If the variables for the SavedModel have not been saved
|
AssertionError: If the variables for the SavedModel have not been saved
|
||||||
yet, or if the graph already contains one or more legacy init ops.
|
yet, or if the graph already contains one or more legacy init ops.
|
||||||
"""
|
"""
|
||||||
# pylint: enable=line-too-long
|
|
||||||
if not self._has_saved_variables:
|
if not self._has_saved_variables:
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
"Graph state including variables and assets has not been saved yet. "
|
"Graph state including variables and assets has not been saved yet. "
|
||||||
@ -330,14 +285,15 @@ class _SavedModelBuilder(object):
|
|||||||
|
|
||||||
# Validate the signature def map to ensure all included TensorInfos are
|
# Validate the signature def map to ensure all included TensorInfos are
|
||||||
# properly populated.
|
# properly populated.
|
||||||
|
signature_def_map = signature_def_map or {}
|
||||||
self._validate_signature_def_map(signature_def_map)
|
self._validate_signature_def_map(signature_def_map)
|
||||||
|
|
||||||
# legacy_init_op is deprecated, and going away in TF 2.0.
|
# Create a SignatureDef pointing to the graph initialization op, which will
|
||||||
# Re-mapping to main_op, as treatment is identical regardless.
|
# be added to the MetaGraphDef.
|
||||||
main_op = main_op or legacy_init_op
|
_add_op_to_signature_def_map(signature_def_map, init_op,
|
||||||
|
constants.INIT_OP_SIGNATURE_KEY)
|
||||||
# Add ops to collection.
|
_add_op_to_signature_def_map(signature_def_map, train_op,
|
||||||
self._add_collections(main_op=main_op, train_op=None)
|
constants.TRAIN_OP_SIGNATURE_KEY)
|
||||||
|
|
||||||
saver = self._maybe_create_saver(saver)
|
saver = self._maybe_create_saver(saver)
|
||||||
|
|
||||||
@ -349,7 +305,7 @@ class _SavedModelBuilder(object):
|
|||||||
# resolved, we just leave the option set to False for now.
|
# resolved, we just leave the option set to False for now.
|
||||||
# TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
|
# TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
|
||||||
meta_graph_def = saver.export_meta_graph(
|
meta_graph_def = saver.export_meta_graph(
|
||||||
clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
|
clear_devices=clear_devices, strip_default_attrs=True)
|
||||||
|
|
||||||
# Save asset files and write them to disk, if any.
|
# Save asset files and write them to disk, if any.
|
||||||
self._save_and_write_assets(meta_graph_def, assets_list)
|
self._save_and_write_assets(meta_graph_def, assets_list)
|
||||||
@ -357,17 +313,14 @@ class _SavedModelBuilder(object):
|
|||||||
# Tag the meta graph def and add it to the SavedModel.
|
# Tag the meta graph def and add it to the SavedModel.
|
||||||
self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
|
self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
|
||||||
|
|
||||||
@deprecated_args(None,
|
|
||||||
"Pass your op to the equivalent parameter main_op instead.",
|
|
||||||
"legacy_init_op")
|
|
||||||
def add_meta_graph_and_variables(self,
|
def add_meta_graph_and_variables(self,
|
||||||
sess,
|
sess,
|
||||||
tags,
|
tags,
|
||||||
signature_def_map=None,
|
signature_def_map=None,
|
||||||
assets_list=None,
|
assets_list=None,
|
||||||
legacy_init_op=None,
|
|
||||||
clear_devices=False,
|
clear_devices=False,
|
||||||
main_op=None,
|
init_op=None,
|
||||||
|
train_op=None,
|
||||||
strip_default_attrs=False,
|
strip_default_attrs=False,
|
||||||
saver=None):
|
saver=None):
|
||||||
# pylint: disable=line-too-long
|
# pylint: disable=line-too-long
|
||||||
@ -386,13 +339,14 @@ class _SavedModelBuilder(object):
|
|||||||
signature_def_map: The map of signature def map to add to the meta graph
|
signature_def_map: The map of signature def map to add to the meta graph
|
||||||
def.
|
def.
|
||||||
assets_list: Assets to be saved with SavedModel.
|
assets_list: Assets to be saved with SavedModel.
|
||||||
legacy_init_op: Legacy support for op or group of ops to execute after the
|
|
||||||
restore op upon a load. Deprecated; please use main_op instead.
|
|
||||||
clear_devices: Set to true if the device info on the default graph should
|
clear_devices: Set to true if the device info on the default graph should
|
||||||
be cleared.
|
be cleared.
|
||||||
main_op: Op or group of ops to execute when the graph is loaded. Note
|
init_op: Op or group of ops to execute when the graph is loaded. Note
|
||||||
that when the main_op is specified it is run after the restore op at
|
that when the init_op is specified it is run after the restore op at
|
||||||
load-time.
|
load-time.
|
||||||
|
train_op: Op or group of ops that trains the model when run. This will
|
||||||
|
not be run automatically when the graph is loaded, instead saved in
|
||||||
|
a SignatureDef accessible through the exported MetaGraph.
|
||||||
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
|
strip_default_attrs: Boolean. If `True`, default-valued attributes will be
|
||||||
removed from the NodeDefs. For a detailed guide, see
|
removed from the NodeDefs. For a detailed guide, see
|
||||||
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
|
[Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
|
||||||
@ -409,14 +363,15 @@ class _SavedModelBuilder(object):
|
|||||||
|
|
||||||
# Validate the signature def map to ensure all included TensorInfos are
|
# Validate the signature def map to ensure all included TensorInfos are
|
||||||
# properly populated.
|
# properly populated.
|
||||||
|
signature_def_map = signature_def_map or {}
|
||||||
self._validate_signature_def_map(signature_def_map)
|
self._validate_signature_def_map(signature_def_map)
|
||||||
|
|
||||||
# legacy_init_op is deprecated, and going away in TF 2.0.
|
# Create a SignatureDef pointing to the graph initialization op, which will
|
||||||
# Re-mapping to main_op, as treatment is identical regardless.
|
# be added to the MetaGraphDef.
|
||||||
main_op = main_op or legacy_init_op
|
_add_op_to_signature_def_map(signature_def_map, init_op,
|
||||||
|
constants.INIT_OP_SIGNATURE_KEY)
|
||||||
# Add ops to collection.
|
_add_op_to_signature_def_map(signature_def_map, train_op,
|
||||||
self._add_collections(main_op=main_op, train_op=None)
|
constants.TRAIN_OP_SIGNATURE_KEY)
|
||||||
|
|
||||||
saved_model_utils.get_or_create_variables_dir(self._export_dir)
|
saved_model_utils.get_or_create_variables_dir(self._export_dir)
|
||||||
variables_path = saved_model_utils.get_variables_path(self._export_dir)
|
variables_path = saved_model_utils.get_variables_path(self._export_dir)
|
||||||
@ -517,6 +472,52 @@ class SavedModelBuilder(_SavedModelBuilder):
|
|||||||
# Copy assets from source path to destination path.
|
# Copy assets from source path to destination path.
|
||||||
self._copy_assets_to_destination_dir(asset_filename_map)
|
self._copy_assets_to_destination_dir(asset_filename_map)
|
||||||
|
|
||||||
|
def _maybe_add_main_op(self, main_op):
|
||||||
|
"""Adds main op to the SavedModel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
main_op: Main op to run as part of graph initialization. If None, no main
|
||||||
|
op will be added to the graph.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if main op is provided but is not of type `Operation`.
|
||||||
|
ValueError: if the Graph already contains an init op.
|
||||||
|
"""
|
||||||
|
if main_op is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not isinstance(main_op, ops.Operation):
|
||||||
|
raise TypeError("main_op needs to be an Operation: %r" % main_op)
|
||||||
|
|
||||||
|
# Validate that no other init ops have been added to this graph already.
|
||||||
|
# We check main_op and legacy_init_op for thoroughness and explicitness.
|
||||||
|
for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
|
||||||
|
if ops.get_collection(init_op_key):
|
||||||
|
raise ValueError(
|
||||||
|
"Graph already contains one or more main ops under the "
|
||||||
|
"collection {}.".format(init_op_key))
|
||||||
|
|
||||||
|
ops.add_to_collection(constants.MAIN_OP_KEY, main_op)
|
||||||
|
|
||||||
|
def _add_train_op(self, train_op):
|
||||||
|
"""Add train op to the SavedModel.
|
||||||
|
|
||||||
|
Note that this functionality is in development, and liable to be
|
||||||
|
moved elsewhere.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_op: Op or group of ops that are used for training. These are stored
|
||||||
|
as a collection with key TRAIN_OP_KEY, but not executed.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError if Train op is not of type `Operation`.
|
||||||
|
"""
|
||||||
|
if train_op is not None:
|
||||||
|
if (not isinstance(train_op, ops.Tensor) and
|
||||||
|
not isinstance(train_op, ops.Operation)):
|
||||||
|
raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
|
||||||
|
ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)
|
||||||
|
|
||||||
@deprecated_args(None,
|
@deprecated_args(None,
|
||||||
"Pass your op to the equivalent parameter main_op instead.",
|
"Pass your op to the equivalent parameter main_op instead.",
|
||||||
"legacy_init_op")
|
"legacy_init_op")
|
||||||
@ -536,6 +537,7 @@ class SavedModelBuilder(_SavedModelBuilder):
|
|||||||
|
|
||||||
# Validate the signature def map to ensure all included TensorInfos are
|
# Validate the signature def map to ensure all included TensorInfos are
|
||||||
# properly populated.
|
# properly populated.
|
||||||
|
signature_def_map = signature_def_map or {}
|
||||||
self._validate_signature_def_map(signature_def_map)
|
self._validate_signature_def_map(signature_def_map)
|
||||||
|
|
||||||
# legacy_init_op is deprecated, and going away in TF 2.0.
|
# legacy_init_op is deprecated, and going away in TF 2.0.
|
||||||
@ -580,6 +582,7 @@ class SavedModelBuilder(_SavedModelBuilder):
|
|||||||
|
|
||||||
# Validate the signature def map to ensure all included TensorInfos are
|
# Validate the signature def map to ensure all included TensorInfos are
|
||||||
# properly populated.
|
# properly populated.
|
||||||
|
signature_def_map = signature_def_map or {}
|
||||||
self._validate_signature_def_map(signature_def_map)
|
self._validate_signature_def_map(signature_def_map)
|
||||||
|
|
||||||
# legacy_init_op is deprecated, and going away in TF 2.0.
|
# legacy_init_op is deprecated, and going away in TF 2.0.
|
||||||
@ -774,3 +777,8 @@ def _add_asset_to_collection(asset_filename, asset_tensor):
|
|||||||
asset_any_proto = Any()
|
asset_any_proto = Any()
|
||||||
asset_any_proto.Pack(asset_proto)
|
asset_any_proto.Pack(asset_proto)
|
||||||
ops.add_to_collection(constants.ASSETS_KEY, asset_any_proto)
|
ops.add_to_collection(constants.ASSETS_KEY, asset_any_proto)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_op_to_signature_def_map(signature_def_map, op, key):
|
||||||
|
if op is not None:
|
||||||
|
signature_def_map[key] = signature_def_utils.op_signature_def(op, key)
|
||||||
|
@ -48,7 +48,6 @@ tf_export(
|
|||||||
# CollectionDef key for the SavedModel main op.
|
# CollectionDef key for the SavedModel main op.
|
||||||
MAIN_OP_KEY = "saved_model_main_op"
|
MAIN_OP_KEY = "saved_model_main_op"
|
||||||
tf_export(
|
tf_export(
|
||||||
"saved_model.MAIN_OP_KEY",
|
|
||||||
v1=["saved_model.MAIN_OP_KEY",
|
v1=["saved_model.MAIN_OP_KEY",
|
||||||
"saved_model.constants.MAIN_OP_KEY"]).export_constant(
|
"saved_model.constants.MAIN_OP_KEY"]).export_constant(
|
||||||
__name__, "MAIN_OP_KEY")
|
__name__, "MAIN_OP_KEY")
|
||||||
@ -105,3 +104,8 @@ tf_export(
|
|||||||
"saved_model.VARIABLES_FILENAME",
|
"saved_model.VARIABLES_FILENAME",
|
||||||
"saved_model.constants.VARIABLES_FILENAME"
|
"saved_model.constants.VARIABLES_FILENAME"
|
||||||
]).export_constant(__name__, "VARIABLES_FILENAME")
|
]).export_constant(__name__, "VARIABLES_FILENAME")
|
||||||
|
|
||||||
|
# The initialization and train ops for a MetaGraph are stored in the
|
||||||
|
# signature def map. The ops are added to the map with the following keys.
|
||||||
|
INIT_OP_SIGNATURE_KEY = "__saved_model_init_op"
|
||||||
|
TRAIN_OP_SIGNATURE_KEY = "__saved_model_train_op"
|
||||||
|
@ -31,6 +31,7 @@ from tensorflow.python.lib.io import file_io
|
|||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import tf_logging
|
from tensorflow.python.platform import tf_logging
|
||||||
from tensorflow.python.saved_model import constants
|
from tensorflow.python.saved_model import constants
|
||||||
|
from tensorflow.python.saved_model import signature_def_utils
|
||||||
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
from tensorflow.python.saved_model import utils_impl as saved_model_utils
|
||||||
from tensorflow.python.training import saver as tf_saver
|
from tensorflow.python.training import saver as tf_saver
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
@ -141,15 +142,46 @@ def _get_main_op_tensor(
|
|||||||
RuntimeError: If the collection def corresponding to the main op key has
|
RuntimeError: If the collection def corresponding to the main op key has
|
||||||
other than exactly one tensor.
|
other than exactly one tensor.
|
||||||
"""
|
"""
|
||||||
|
# TODO(kathywu): Rename this method to _get_op_from_collection when
|
||||||
|
# dependency from SavedModelEstimator is removed.
|
||||||
collection_def = meta_graph_def_to_load.collection_def
|
collection_def = meta_graph_def_to_load.collection_def
|
||||||
main_op_tensor = None
|
init_op = None
|
||||||
if init_op_key in collection_def:
|
if init_op_key in collection_def:
|
||||||
main_ops = collection_def[init_op_key].node_list.value
|
init_op_list = collection_def[init_op_key].node_list.value
|
||||||
if len(main_ops) != 1:
|
if len(init_op_list) != 1:
|
||||||
raise RuntimeError("Expected exactly one SavedModel main op. "
|
raise RuntimeError("Expected exactly one SavedModel init op. "
|
||||||
"Found: {}".format(main_ops))
|
"Found: {}".format(init_op_list))
|
||||||
main_op_tensor = ops.get_collection(init_op_key)[0]
|
init_op = ops.get_collection(init_op_key)[0]
|
||||||
return main_op_tensor
|
return init_op
|
||||||
|
|
||||||
|
|
||||||
|
def _get_op_from_collection(meta_graph_def, op_key):
|
||||||
|
return _get_main_op_tensor(meta_graph_def, op_key)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_op_from_signature_def(meta_graph_def, op_signature_key, import_scope):
|
||||||
|
"""Retrieve op stored in the imported meta graph's signature def."""
|
||||||
|
if op_signature_key in meta_graph_def.signature_def:
|
||||||
|
return signature_def_utils.load_op_from_signature_def(
|
||||||
|
meta_graph_def.signature_def[op_signature_key], op_signature_key,
|
||||||
|
import_scope)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_init_op(meta_graph_def, import_scope=None):
|
||||||
|
return (_get_op_from_signature_def(
|
||||||
|
meta_graph_def, constants.INIT_OP_SIGNATURE_KEY, import_scope) or
|
||||||
|
_get_op_from_collection(meta_graph_def, constants.MAIN_OP_KEY) or
|
||||||
|
_get_op_from_collection(meta_graph_def, constants.LEGACY_INIT_OP_KEY))
|
||||||
|
|
||||||
|
|
||||||
|
def get_train_op(meta_graph_def, import_scope=None):
|
||||||
|
train_op = _get_op_from_signature_def(
|
||||||
|
meta_graph_def, constants.TRAIN_OP_SIGNATURE_KEY, import_scope)
|
||||||
|
if train_op is None:
|
||||||
|
train_op = _get_op_from_collection(meta_graph_def, constants.TRAIN_OP_KEY)
|
||||||
|
return train_op
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=[
|
@tf_export(v1=[
|
||||||
@ -359,11 +391,9 @@ class SavedModelLoader(object):
|
|||||||
asset_tensors_dictionary = _get_asset_tensors(
|
asset_tensors_dictionary = _get_asset_tensors(
|
||||||
self._export_dir, meta_graph_def, import_scope=import_scope)
|
self._export_dir, meta_graph_def, import_scope=import_scope)
|
||||||
|
|
||||||
main_op_tensor = (
|
init_op = get_init_op(meta_graph_def, import_scope)
|
||||||
_get_main_op_tensor(meta_graph_def, constants.MAIN_OP_KEY) or
|
if init_op is not None:
|
||||||
_get_main_op_tensor(meta_graph_def, constants.LEGACY_INIT_OP_KEY))
|
sess.run(fetches=[init_op], feed_dict=asset_tensors_dictionary)
|
||||||
if main_op_tensor is not None:
|
|
||||||
sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
|
|
||||||
|
|
||||||
def load(self, sess, tags, import_scope=None, **saver_kwargs):
|
def load(self, sess, tags, import_scope=None, **saver_kwargs):
|
||||||
"""Load the MetaGraphDef graph and restore variable values into the session.
|
"""Load the MetaGraphDef graph and restore variable values into the session.
|
||||||
|
@ -19,11 +19,13 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python.client import session
|
from tensorflow.python.client import session
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.lib.io import file_io
|
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
@ -42,55 +44,72 @@ SIMPLE_ADD_SAVED_MODEL = _get_export_dir("simple_add_saved_model")
|
|||||||
SAVED_MODEL_WITH_MAIN_OP = _get_export_dir("saved_model_with_main_op")
|
SAVED_MODEL_WITH_MAIN_OP = _get_export_dir("saved_model_with_main_op")
|
||||||
|
|
||||||
|
|
||||||
class SavedModelLoaderTest(test.TestCase):
|
def build_graph_helper():
|
||||||
|
g = ops.Graph()
|
||||||
def setUp(self):
|
with g.as_default():
|
||||||
"""Write test SavedModels to a temp directory."""
|
|
||||||
with session.Session(graph=ops.Graph()) as sess:
|
|
||||||
x = variables.VariableV1(5, name="x")
|
x = variables.VariableV1(5, name="x")
|
||||||
y = variables.VariableV1(11, name="y")
|
y = variables.VariableV1(11, name="y")
|
||||||
z = x + y
|
z = x + y
|
||||||
|
|
||||||
|
foo_sig_def = signature_def_utils.build_signature_def({
|
||||||
|
"foo_input": utils.build_tensor_info(x)
|
||||||
|
}, {"foo_output": utils.build_tensor_info(z)})
|
||||||
|
bar_sig_def = signature_def_utils.build_signature_def({
|
||||||
|
"bar_x": utils.build_tensor_info(x),
|
||||||
|
"bar_y": utils.build_tensor_info(y)
|
||||||
|
}, {"bar_z": utils.build_tensor_info(z)})
|
||||||
|
return g, {"foo": foo_sig_def, "bar": bar_sig_def}, y
|
||||||
|
|
||||||
|
|
||||||
|
@parameterized.parameters((saved_model_builder.SavedModelBuilder,),
|
||||||
|
(saved_model_builder._SavedModelBuilder,))
|
||||||
|
class SavedModelLoaderTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def export_simple_graph(self, builder_cls):
|
||||||
|
g, sig_def_map, _ = build_graph_helper()
|
||||||
|
with session.Session(graph=g) as sess:
|
||||||
self.evaluate(variables.global_variables_initializer())
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
builder = builder_cls(SIMPLE_ADD_SAVED_MODEL)
|
||||||
foo_sig_def = signature_def_utils.build_signature_def(
|
builder.add_meta_graph_and_variables(sess, ["foo_graph"], sig_def_map)
|
||||||
{"foo_input": utils.build_tensor_info(x)},
|
|
||||||
{"foo_output": utils.build_tensor_info(z)})
|
|
||||||
bar_sig_def = signature_def_utils.build_signature_def(
|
|
||||||
{"bar_x": utils.build_tensor_info(x),
|
|
||||||
"bar_y": utils.build_tensor_info(y)},
|
|
||||||
{"bar_z": utils.build_tensor_info(z)})
|
|
||||||
|
|
||||||
builder = saved_model_builder.SavedModelBuilder(SIMPLE_ADD_SAVED_MODEL)
|
|
||||||
builder.add_meta_graph_and_variables(
|
|
||||||
sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def})
|
|
||||||
builder.save()
|
builder.save()
|
||||||
|
|
||||||
# Write SavedModel with a main_op
|
def export_graph_with_main_op(self, builder_cls):
|
||||||
|
g, sig_def_map, y = build_graph_helper()
|
||||||
|
with session.Session(graph=g) as sess:
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
assign_op = control_flow_ops.group(state_ops.assign(y, 7))
|
assign_op = control_flow_ops.group(state_ops.assign(y, 7))
|
||||||
|
|
||||||
builder = saved_model_builder.SavedModelBuilder(SAVED_MODEL_WITH_MAIN_OP)
|
builder = builder_cls(SAVED_MODEL_WITH_MAIN_OP)
|
||||||
|
|
||||||
|
if builder_cls == saved_model_builder._SavedModelBuilder:
|
||||||
builder.add_meta_graph_and_variables(
|
builder.add_meta_graph_and_variables(
|
||||||
sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def},
|
sess, ["foo_graph"], sig_def_map, init_op=assign_op)
|
||||||
main_op=assign_op)
|
else:
|
||||||
|
builder.add_meta_graph_and_variables(
|
||||||
|
sess, ["foo_graph"], sig_def_map, main_op=assign_op)
|
||||||
builder.save()
|
builder.save()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
file_io.delete_recursively(test.get_temp_dir())
|
super(SavedModelLoaderTest, self).tearDown()
|
||||||
|
shutil.rmtree(test.get_temp_dir(), ignore_errors=True)
|
||||||
|
|
||||||
def test_load_function(self):
|
def test_load_function(self, builder_cls):
|
||||||
|
self.export_simple_graph(builder_cls)
|
||||||
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
loader.load(sess, ["foo_graph"])
|
loader.load(sess, ["foo_graph"])
|
||||||
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
|
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
|
||||||
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
|
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
|
||||||
|
|
||||||
|
self.export_graph_with_main_op(builder_cls)
|
||||||
loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
|
loader2 = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
loader2.load(sess, ["foo_graph"])
|
loader2.load(sess, ["foo_graph"])
|
||||||
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
|
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
|
||||||
self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
|
self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
|
||||||
|
|
||||||
def test_load_graph(self):
|
def test_load_graph(self, builder_cls):
|
||||||
|
self.export_simple_graph(builder_cls)
|
||||||
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
||||||
graph = ops.Graph()
|
graph = ops.Graph()
|
||||||
loader.load_graph(graph, ["foo_graph"])
|
loader.load_graph(graph, ["foo_graph"])
|
||||||
@ -101,14 +120,15 @@ class SavedModelLoaderTest(test.TestCase):
|
|||||||
with self.assertRaises(KeyError):
|
with self.assertRaises(KeyError):
|
||||||
graph.get_tensor_by_name("z:0")
|
graph.get_tensor_by_name("z:0")
|
||||||
|
|
||||||
with self.session(graph=graph) as sess:
|
with self.session(graph=graph):
|
||||||
# Check that x and y are not initialized
|
# Check that x and y are not initialized
|
||||||
with self.assertRaises(errors.FailedPreconditionError):
|
with self.assertRaises(errors.FailedPreconditionError):
|
||||||
self.evaluate(x)
|
self.evaluate(x)
|
||||||
with self.assertRaises(errors.FailedPreconditionError):
|
with self.assertRaises(errors.FailedPreconditionError):
|
||||||
self.evaluate(y)
|
self.evaluate(y)
|
||||||
|
|
||||||
def test_load_with_import_scope(self):
|
def test_load_with_import_scope(self, builder_cls):
|
||||||
|
self.export_graph_with_main_op(builder_cls)
|
||||||
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
|
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
saver, _ = loader.load_graph(
|
saver, _ = loader.load_graph(
|
||||||
@ -119,6 +139,12 @@ class SavedModelLoaderTest(test.TestCase):
|
|||||||
loader.restore_variables(sess, tf_saver.Saver())
|
loader.restore_variables(sess, tf_saver.Saver())
|
||||||
|
|
||||||
loader.restore_variables(sess, saver)
|
loader.restore_variables(sess, saver)
|
||||||
|
|
||||||
|
if builder_cls == saved_model_builder._SavedModelBuilder:
|
||||||
|
with self.assertRaises(errors.NotFoundError):
|
||||||
|
loader.run_init_ops(sess, ["foo_graph"])
|
||||||
|
loader.run_init_ops(sess, ["foo_graph"], import_scope="baz")
|
||||||
|
else:
|
||||||
loader.run_init_ops(sess, ["foo_graph"])
|
loader.run_init_ops(sess, ["foo_graph"])
|
||||||
|
|
||||||
self.assertEqual(5, sess.graph.get_tensor_by_name("baz/x:0").eval())
|
self.assertEqual(5, sess.graph.get_tensor_by_name("baz/x:0").eval())
|
||||||
@ -131,7 +157,8 @@ class SavedModelLoaderTest(test.TestCase):
|
|||||||
self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval())
|
self.assertEqual(5, sess.graph.get_tensor_by_name("baa/x:0").eval())
|
||||||
self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval())
|
self.assertEqual(7, sess.graph.get_tensor_by_name("baa/y:0").eval())
|
||||||
|
|
||||||
def test_restore_variables(self):
|
def test_restore_variables(self, builder_cls):
|
||||||
|
self.export_graph_with_main_op(builder_cls)
|
||||||
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
|
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
x = variables.VariableV1(0, name="x")
|
x = variables.VariableV1(0, name="x")
|
||||||
@ -147,7 +174,8 @@ class SavedModelLoaderTest(test.TestCase):
|
|||||||
loader.restore_variables(sess, tf_saver.Saver())
|
loader.restore_variables(sess, tf_saver.Saver())
|
||||||
self.assertEqual(55, self.evaluate(z))
|
self.assertEqual(55, self.evaluate(z))
|
||||||
|
|
||||||
def test_run_init_op(self):
|
def test_run_init_op(self, builder_cls):
|
||||||
|
self.export_graph_with_main_op(builder_cls)
|
||||||
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
|
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
|
||||||
graph = ops.Graph()
|
graph = ops.Graph()
|
||||||
saver, _ = loader.load_graph(graph, ["foo_graph"])
|
saver, _ = loader.load_graph(graph, ["foo_graph"])
|
||||||
@ -160,14 +188,16 @@ class SavedModelLoaderTest(test.TestCase):
|
|||||||
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
|
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
|
||||||
self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
|
self.assertEqual(7, sess.graph.get_tensor_by_name("y:0").eval())
|
||||||
|
|
||||||
def test_parse_saved_model(self):
|
def test_parse_saved_model(self, builder_cls):
|
||||||
|
self.export_simple_graph(builder_cls)
|
||||||
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
||||||
meta_graph = loader.get_meta_graph_def_from_tags(["foo_graph"])
|
meta_graph = loader.get_meta_graph_def_from_tags(["foo_graph"])
|
||||||
self.assertIsNotNone(meta_graph)
|
self.assertIsNotNone(meta_graph)
|
||||||
self.assertIn("foo", meta_graph.signature_def)
|
self.assertIn("foo", meta_graph.signature_def)
|
||||||
self.assertIn("bar", meta_graph.signature_def)
|
self.assertIn("bar", meta_graph.signature_def)
|
||||||
|
|
||||||
def test_load_invalid_meta_graph(self):
|
def test_load_invalid_meta_graph(self, builder_cls):
|
||||||
|
self.export_simple_graph(builder_cls)
|
||||||
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
loader.get_meta_graph_def_from_tags([])
|
loader.get_meta_graph_def_from_tags([])
|
||||||
@ -176,13 +206,16 @@ class SavedModelLoaderTest(test.TestCase):
|
|||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
loader.get_meta_graph_def_from_tags(["not_a_graph"])
|
loader.get_meta_graph_def_from_tags(["not_a_graph"])
|
||||||
|
|
||||||
def test_load_saved_model_with_no_variables(self):
|
def test_load_saved_model_with_no_variables(self, builder_cls):
|
||||||
"""Test that SavedModel runs saver when there appear to be no variables.
|
"""Test that SavedModel runs saver when there appear to be no variables.
|
||||||
|
|
||||||
When no variables are detected, this may mean that the variables were saved
|
When no variables are detected, this may mean that the variables were saved
|
||||||
to different collections, or the collections weren't saved to the
|
to different collections, or the collections weren't saved to the
|
||||||
SavedModel. If the SavedModel MetaGraphDef contains a saver, it should still
|
SavedModel. If the SavedModel MetaGraphDef contains a saver, it should still
|
||||||
run in either of these cases.
|
run in either of these cases.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
builder_cls: SavedModelBuilder or _SavedModelBuilder class
|
||||||
"""
|
"""
|
||||||
path = _get_export_dir("no_variable_saved_model")
|
path = _get_export_dir("no_variable_saved_model")
|
||||||
with session.Session(graph=ops.Graph()) as sess:
|
with session.Session(graph=ops.Graph()) as sess:
|
||||||
@ -192,7 +225,7 @@ class SavedModelLoaderTest(test.TestCase):
|
|||||||
11, name="y", collections=["not_global_variable"])
|
11, name="y", collections=["not_global_variable"])
|
||||||
self.assertFalse(variables._all_saveable_objects())
|
self.assertFalse(variables._all_saveable_objects())
|
||||||
z = x + y
|
z = x + y
|
||||||
sess.run(variables.variables_initializer([x, y]))
|
self.evaluate(variables.variables_initializer([x, y]))
|
||||||
|
|
||||||
foo_sig_def = signature_def_utils.build_signature_def(
|
foo_sig_def = signature_def_utils.build_signature_def(
|
||||||
{"foo_input": utils.build_tensor_info(x)},
|
{"foo_input": utils.build_tensor_info(x)},
|
||||||
@ -215,8 +248,9 @@ class SavedModelLoaderTest(test.TestCase):
|
|||||||
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
|
self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
|
||||||
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
|
self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
|
||||||
|
|
||||||
def test_load_saved_model_graph_with_return_elements(self):
|
def test_load_saved_model_graph_with_return_elements(self, builder_cls):
|
||||||
"""Ensure that the correct elements are returned."""
|
"""Ensure that the correct elements are returned."""
|
||||||
|
self.export_simple_graph(builder_cls)
|
||||||
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
loader = loader_impl.SavedModelLoader(SIMPLE_ADD_SAVED_MODEL)
|
||||||
graph = ops.Graph()
|
graph = ops.Graph()
|
||||||
_, ret = loader.load_graph(graph, ["foo_graph"],
|
_, ret = loader.load_graph(graph, ["foo_graph"],
|
||||||
@ -228,5 +262,6 @@ class SavedModelLoaderTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(ValueError, "not found in graph"):
|
with self.assertRaisesRegexp(ValueError, "not found in graph"):
|
||||||
loader.load_graph(graph, ["foo_graph"], return_elements=["z:0"])
|
loader.load_graph(graph, ["foo_graph"], return_elements=["z:0"])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -146,6 +146,18 @@ class SavedModelTest(SavedModelTestBase):
|
|||||||
sess, ["foo"],
|
sess, ["foo"],
|
||||||
signature_def_map={"foo_key": foo_signature})
|
signature_def_map={"foo_key": foo_signature})
|
||||||
|
|
||||||
|
def _validate_sig_def_keys(self, builder, valid_tensor_info, invalid_key):
|
||||||
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
|
self._init_and_validate_variable(sess, "v", 42)
|
||||||
|
|
||||||
|
foo_signature = signature_def_utils.build_signature_def(
|
||||||
|
dict(), {"foo_key": valid_tensor_info}, "foo")
|
||||||
|
self.assertRaises(
|
||||||
|
KeyError,
|
||||||
|
builder.add_meta_graph_and_variables,
|
||||||
|
sess, ["foo"],
|
||||||
|
signature_def_map={invalid_key: foo_signature})
|
||||||
|
|
||||||
def testMaybeSavedModelDir(self):
|
def testMaybeSavedModelDir(self):
|
||||||
base_path = test.test_src_dir_path("/python/saved_model")
|
base_path = test.test_src_dir_path("/python/saved_model")
|
||||||
self.assertFalse(loader.maybe_saved_model_directory(base_path))
|
self.assertFalse(loader.maybe_saved_model_directory(base_path))
|
||||||
@ -583,6 +595,15 @@ class SavedModelTest(SavedModelTestBase):
|
|||||||
self._validate_inputs_tensor_info_fail(builder, tensor_empty)
|
self._validate_inputs_tensor_info_fail(builder, tensor_empty)
|
||||||
self._validate_outputs_tensor_info_fail(builder, tensor_empty)
|
self._validate_outputs_tensor_info_fail(builder, tensor_empty)
|
||||||
|
|
||||||
|
valid_tensor_info = meta_graph_pb2.TensorInfo()
|
||||||
|
valid_tensor_info.name = "foo"
|
||||||
|
valid_tensor_info.dtype = types_pb2.DT_FLOAT
|
||||||
|
|
||||||
|
self._validate_sig_def_keys(builder, valid_tensor_info,
|
||||||
|
constants.INIT_OP_SIGNATURE_KEY)
|
||||||
|
self._validate_sig_def_keys(builder, valid_tensor_info,
|
||||||
|
constants.TRAIN_OP_SIGNATURE_KEY)
|
||||||
|
|
||||||
def testSignatureDefValidationSucceedsWithName(self):
|
def testSignatureDefValidationSucceedsWithName(self):
|
||||||
tensor_with_name = meta_graph_pb2.TensorInfo()
|
tensor_with_name = meta_graph_pb2.TensorInfo()
|
||||||
tensor_with_name.name = "foo"
|
tensor_with_name.name = "foo"
|
||||||
@ -782,7 +803,7 @@ class SavedModelTest(SavedModelTestBase):
|
|||||||
self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt",
|
self._validate_assets(export_dir, foo_graph.asset_file_def, "hello42.txt",
|
||||||
"foo bar baz 0", "asset_file_tensor_0:0")
|
"foo bar baz 0", "asset_file_tensor_0:0")
|
||||||
|
|
||||||
def testCustomMainOp(self):
|
def testCustomInitOp(self):
|
||||||
export_dir = self._get_export_dir("test_main_op")
|
export_dir = self._get_export_dir("test_main_op")
|
||||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
@ -800,11 +821,11 @@ class SavedModelTest(SavedModelTestBase):
|
|||||||
# Set up an assignment op to be run as part of the main_op.
|
# Set up an assignment op to be run as part of the main_op.
|
||||||
with ops.control_dependencies([main_op.main_op()]):
|
with ops.control_dependencies([main_op.main_op()]):
|
||||||
add_v1_v2 = math_ops.add(v1._ref(), v2._ref())
|
add_v1_v2 = math_ops.add(v1._ref(), v2._ref())
|
||||||
custom_main_op = control_flow_ops.group(state_ops.assign(v3, add_v1_v2))
|
custom_init_op = control_flow_ops.group(state_ops.assign(v3, add_v1_v2))
|
||||||
|
|
||||||
self.evaluate(custom_main_op)
|
self.evaluate(custom_init_op)
|
||||||
builder.add_meta_graph_and_variables(
|
builder.add_meta_graph_and_variables(
|
||||||
sess, ["foo"], main_op=custom_main_op)
|
sess, ["foo"], init_op=custom_init_op)
|
||||||
|
|
||||||
# Save the SavedModel to disk.
|
# Save the SavedModel to disk.
|
||||||
builder.save()
|
builder.save()
|
||||||
@ -817,80 +838,6 @@ class SavedModelTest(SavedModelTestBase):
|
|||||||
# the main_op, following a restore.
|
# the main_op, following a restore.
|
||||||
self.assertEqual(3, ops.get_collection("v")[2].eval())
|
self.assertEqual(3, ops.get_collection("v")[2].eval())
|
||||||
|
|
||||||
def testLegacyInitOp(self):
|
|
||||||
export_dir = self._get_export_dir("test_legacy_init_op")
|
|
||||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
|
||||||
# Add `v1` and `v2` variables to the graph.
|
|
||||||
v1 = variables.VariableV1(1, name="v1")
|
|
||||||
ops.add_to_collection("v", v1)
|
|
||||||
v2 = variables.VariableV1(2, name="v2")
|
|
||||||
ops.add_to_collection("v", v2)
|
|
||||||
|
|
||||||
# Initialize another variable `v3` to 42.
|
|
||||||
v3 = variables.VariableV1(42, name="v3", trainable=False, collections=[])
|
|
||||||
ops.add_to_collection("v", v3)
|
|
||||||
|
|
||||||
# Set up an assignment op to be run as part of the legacy_init_op.
|
|
||||||
assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2))
|
|
||||||
legacy_init_op = control_flow_ops.group(assign_v3, name="legacy_init_op")
|
|
||||||
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
|
||||||
builder.add_meta_graph_and_variables(
|
|
||||||
sess, ["foo"], legacy_init_op=legacy_init_op)
|
|
||||||
|
|
||||||
# Save the SavedModel to disk.
|
|
||||||
builder.save()
|
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
|
||||||
loader.load(sess, ["foo"], export_dir)
|
|
||||||
self.assertEqual(1, ops.get_collection("v")[0].eval())
|
|
||||||
self.assertEqual(2, ops.get_collection("v")[1].eval())
|
|
||||||
# Evaluates to the sum of the first two variables and assigned as part of
|
|
||||||
# the legacy_init_op, following a restore.
|
|
||||||
self.assertEqual(3, ops.get_collection("v")[2].eval())
|
|
||||||
|
|
||||||
def testLegacyInitOpWithNonEmptyCollection(self):
|
|
||||||
export_dir = self._get_export_dir(
|
|
||||||
"test_legacy_init_op_with_non_empty_collection")
|
|
||||||
self._testInitOpsWithNonEmptyCollection(
|
|
||||||
export_dir, constants.LEGACY_INIT_OP_KEY)
|
|
||||||
|
|
||||||
def testMainOpWithNonEmptyCollection(self):
|
|
||||||
export_dir = self._get_export_dir(
|
|
||||||
"test_main_op_with_non_empty_collection")
|
|
||||||
self._testInitOpsWithNonEmptyCollection(export_dir, constants.MAIN_OP_KEY)
|
|
||||||
|
|
||||||
def _testInitOpsWithNonEmptyCollection(self, export_dir, key):
|
|
||||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
|
||||||
|
|
||||||
g = ops.Graph()
|
|
||||||
with self.session(graph=g) as sess:
|
|
||||||
# Initialize variable `v1` to 1.
|
|
||||||
v1 = variables.VariableV1(1, name="v1")
|
|
||||||
ops.add_to_collection("v", v1)
|
|
||||||
|
|
||||||
# Initialize another variable `v2` to 42.
|
|
||||||
v2 = variables.VariableV1(42, name="v2", trainable=False, collections=[])
|
|
||||||
ops.add_to_collection("v", v2)
|
|
||||||
|
|
||||||
# Set up an assignment op to be run as part of the init op.
|
|
||||||
assign_v2 = state_ops.assign(v2, v1)
|
|
||||||
init_op = control_flow_ops.group(assign_v2, name="init_op")
|
|
||||||
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
|
||||||
|
|
||||||
ops.add_to_collection(key, control_flow_ops.no_op())
|
|
||||||
# ValueError should be raised since the LEGACY_INIT_OP_KEY collection
|
|
||||||
# is not empty and we don't support multiple init ops.
|
|
||||||
with self.assertRaisesRegexp(ValueError, "Graph already contains"):
|
|
||||||
builder.add_meta_graph_and_variables(
|
|
||||||
sess, ["foo"], legacy_init_op=init_op)
|
|
||||||
# We shouldn't be able to add as MAIN_OP, either.
|
|
||||||
with self.assertRaisesRegexp(ValueError, "Graph already contains"):
|
|
||||||
builder.add_meta_graph_and_variables(sess, ["foo"], main_op=init_op)
|
|
||||||
|
|
||||||
def testTrainOp(self):
|
def testTrainOp(self):
|
||||||
export_dir = self._get_export_dir("test_train_op")
|
export_dir = self._get_export_dir("test_train_op")
|
||||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
||||||
@ -906,19 +853,17 @@ class SavedModelTest(SavedModelTestBase):
|
|||||||
train_op = state_ops.assign_add(v1, v2)
|
train_op = state_ops.assign_add(v1, v2)
|
||||||
|
|
||||||
self.evaluate(train_op)
|
self.evaluate(train_op)
|
||||||
# TODO(karmel): remove explicit call when in the public method.
|
builder.add_meta_graph_and_variables(sess, ["foo"], train_op=train_op)
|
||||||
builder._add_train_op(train_op)
|
|
||||||
builder.add_meta_graph_and_variables(sess, ["foo"])
|
|
||||||
|
|
||||||
# Save the SavedModel to disk.
|
# Save the SavedModel to disk.
|
||||||
builder.save()
|
builder.save()
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
loader.load(sess, ["foo"], export_dir)
|
meta_graph_def = loader.load(sess, ["foo"], export_dir)
|
||||||
self.assertEqual(3, ops.get_collection("v")[0].eval())
|
self.assertEqual(3, ops.get_collection("v")[0].eval())
|
||||||
self.assertEqual(2, ops.get_collection("v")[1].eval())
|
self.assertEqual(2, ops.get_collection("v")[1].eval())
|
||||||
self.assertIsInstance(
|
self.assertIsInstance(
|
||||||
ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor)
|
loader_impl.get_train_op(meta_graph_def), ops.Tensor)
|
||||||
|
|
||||||
def testTrainOpGroup(self):
|
def testTrainOpGroup(self):
|
||||||
export_dir = self._get_export_dir("test_train_op_group")
|
export_dir = self._get_export_dir("test_train_op_group")
|
||||||
@ -935,19 +880,17 @@ class SavedModelTest(SavedModelTestBase):
|
|||||||
train_op = control_flow_ops.group()
|
train_op = control_flow_ops.group()
|
||||||
|
|
||||||
self.evaluate(train_op)
|
self.evaluate(train_op)
|
||||||
# TODO(karmel): remove explicit call when in the public method.
|
builder.add_meta_graph_and_variables(sess, ["foo"], train_op=train_op)
|
||||||
builder._add_train_op(train_op)
|
|
||||||
builder.add_meta_graph_and_variables(sess, ["foo"])
|
|
||||||
|
|
||||||
# Save the SavedModel to disk.
|
# Save the SavedModel to disk.
|
||||||
builder.save()
|
builder.save()
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
loader.load(sess, ["foo"], export_dir)
|
meta_graph_def = loader.load(sess, ["foo"], export_dir)
|
||||||
self.assertEqual(1, ops.get_collection("v")[0].eval())
|
self.assertEqual(1, ops.get_collection("v")[0].eval())
|
||||||
self.assertEqual(2, ops.get_collection("v")[1].eval())
|
self.assertEqual(2, ops.get_collection("v")[1].eval())
|
||||||
self.assertIsInstance(
|
self.assertIsInstance(
|
||||||
ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Operation)
|
loader_impl.get_train_op(meta_graph_def), ops.Operation)
|
||||||
|
|
||||||
def testTrainOpAfterVariables(self):
|
def testTrainOpAfterVariables(self):
|
||||||
export_dir = self._get_export_dir("test_train_op_after_variables")
|
export_dir = self._get_export_dir("test_train_op_after_variables")
|
||||||
@ -965,17 +908,15 @@ class SavedModelTest(SavedModelTestBase):
|
|||||||
|
|
||||||
train_op = state_ops.assign_add(v1, v2)
|
train_op = state_ops.assign_add(v1, v2)
|
||||||
self.evaluate(train_op)
|
self.evaluate(train_op)
|
||||||
# TODO(karmel): remove explicit call when in the public method.
|
builder.add_meta_graph(["foo"], train_op=train_op)
|
||||||
builder._add_train_op(train_op)
|
|
||||||
builder.add_meta_graph(["foo"])
|
|
||||||
|
|
||||||
# Save the SavedModel to disk.
|
# Save the SavedModel to disk.
|
||||||
builder.save()
|
builder.save()
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
loader.load(sess, ["foo"], export_dir)
|
meta_graph_def = loader.load(sess, ["foo"], export_dir)
|
||||||
self.assertIsInstance(
|
self.assertIsInstance(
|
||||||
ops.get_collection(constants.TRAIN_OP_KEY)[0], ops.Tensor)
|
loader_impl.get_train_op(meta_graph_def), ops.Tensor)
|
||||||
|
|
||||||
with self.session(graph=ops.Graph()) as sess:
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
loader.load(sess, ["pre_foo"], export_dir)
|
loader.load(sess, ["pre_foo"], export_dir)
|
||||||
@ -1288,76 +1229,6 @@ class SavedModelTest(SavedModelTestBase):
|
|||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
|
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
|
||||||
|
|
||||||
def testStripDefaultAttrs(self):
|
|
||||||
export_dir = self._get_export_dir("test_strip_default_attrs")
|
|
||||||
builder = saved_model_builder._SavedModelBuilder(export_dir)
|
|
||||||
|
|
||||||
# Add a graph with two float32 variables and a Complex Op composing them
|
|
||||||
# with strip_default_attrs enabled.
|
|
||||||
with session.Session(graph=ops.Graph()) as sess:
|
|
||||||
real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
|
|
||||||
imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
|
|
||||||
math_ops.complex(real_num, imag_num, name="complex")
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
|
||||||
builder.add_meta_graph_and_variables(
|
|
||||||
sess, ["foo"], strip_default_attrs=True)
|
|
||||||
|
|
||||||
# Add a graph with the same float32 variables and a Complex Op composing
|
|
||||||
# them with strip_default_attrs disabled.
|
|
||||||
with session.Session(graph=ops.Graph()) as sess:
|
|
||||||
real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
|
|
||||||
imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
|
|
||||||
math_ops.complex(real_num, imag_num, name="complex")
|
|
||||||
self.evaluate(variables.global_variables_initializer())
|
|
||||||
builder.add_meta_graph(["bar"], strip_default_attrs=False)
|
|
||||||
|
|
||||||
# Save the SavedModel to disk in text format.
|
|
||||||
builder.save(as_text=True)
|
|
||||||
|
|
||||||
# Loading graph "foo" via the loader must restore the defaults for the
|
|
||||||
# "Complex" node based on the "Complex" OpDef in the Op registry.
|
|
||||||
sess = session.Session(graph=ops.Graph())
|
|
||||||
meta_graph_def = loader.load(sess, ["foo"], export_dir)
|
|
||||||
complex_node = test_util.get_node_def_from_graph("complex",
|
|
||||||
meta_graph_def.graph_def)
|
|
||||||
self.assertIn("T", complex_node.attr)
|
|
||||||
self.assertIn("Tout", complex_node.attr)
|
|
||||||
|
|
||||||
# Load graph "foo" from disk as-is to verify default attrs are stripped.
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
saved_model_pb = loader_impl._parse_saved_model(export_dir)
|
|
||||||
self.assertIsNotNone(saved_model_pb)
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
|
|
||||||
meta_graph_foo_def = None
|
|
||||||
meta_graph_bar_def = None
|
|
||||||
for meta_graph_def in saved_model_pb.meta_graphs:
|
|
||||||
if set(meta_graph_def.meta_info_def.tags) == set(["foo"]):
|
|
||||||
meta_graph_foo_def = meta_graph_def
|
|
||||||
elif set(meta_graph_def.meta_info_def.tags) == set(["bar"]):
|
|
||||||
meta_graph_bar_def = meta_graph_def
|
|
||||||
|
|
||||||
self.assertIsNotNone(meta_graph_foo_def)
|
|
||||||
self.assertIsNotNone(meta_graph_bar_def)
|
|
||||||
|
|
||||||
# "Complex" Op has 2 attributes with defaults:
|
|
||||||
# o "T" : float32. (input type)
|
|
||||||
# o "Tout" : complex64. (output type)
|
|
||||||
|
|
||||||
# "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
|
|
||||||
# Graph "foo" was saved with strip_default_attrs set to True.
|
|
||||||
node_def = test_util.get_node_def_from_graph("complex",
|
|
||||||
meta_graph_foo_def.graph_def)
|
|
||||||
self.assertNotIn("T", node_def.attr)
|
|
||||||
self.assertNotIn("Tout", node_def.attr)
|
|
||||||
|
|
||||||
# "Complex" Op in graph "bar" must have attributes "T" and "Tout".
|
|
||||||
# Graph "bar" was saved with strip_default_attrs set to False.
|
|
||||||
node_def = test_util.get_node_def_from_graph("complex",
|
|
||||||
meta_graph_bar_def.graph_def)
|
|
||||||
self.assertIn("T", node_def.attr)
|
|
||||||
self.assertIn("Tout", node_def.attr)
|
|
||||||
|
|
||||||
# Tests the behavior of loading SavedModels that having missing attrs or attrs
|
# Tests the behavior of loading SavedModels that having missing attrs or attrs
|
||||||
# with incorrect types.
|
# with incorrect types.
|
||||||
def testInconsistentConsumerDefaultAttrs(self):
|
def testInconsistentConsumerDefaultAttrs(self):
|
||||||
@ -1484,6 +1355,149 @@ class SavedModelV1Test(SavedModelTestBase):
|
|||||||
compat.as_bytes("ignored.txt"))
|
compat.as_bytes("ignored.txt"))
|
||||||
self.assertFalse(file_io.file_exists(ignored_asset_path))
|
self.assertFalse(file_io.file_exists(ignored_asset_path))
|
||||||
|
|
||||||
|
def testLegacyInitOpWithNonEmptyCollection(self):
|
||||||
|
export_dir = self._get_export_dir(
|
||||||
|
"test_legacy_init_op_with_non_empty_collection")
|
||||||
|
self._testInitOpsWithNonEmptyCollection(export_dir,
|
||||||
|
constants.LEGACY_INIT_OP_KEY)
|
||||||
|
|
||||||
|
def testMainOpWithNonEmptyCollection(self):
|
||||||
|
export_dir = self._get_export_dir("test_main_op_with_non_empty_collection")
|
||||||
|
self._testInitOpsWithNonEmptyCollection(export_dir, constants.MAIN_OP_KEY)
|
||||||
|
|
||||||
|
def _testInitOpsWithNonEmptyCollection(self, export_dir, key):
|
||||||
|
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
|
g = ops.Graph()
|
||||||
|
with self.session(graph=g) as sess:
|
||||||
|
# Initialize variable `v1` to 1.
|
||||||
|
v1 = variables.VariableV1(1, name="v1")
|
||||||
|
ops.add_to_collection("v", v1)
|
||||||
|
|
||||||
|
# Initialize another variable `v2` to 42.
|
||||||
|
v2 = variables.VariableV1(42, name="v2", trainable=False, collections=[])
|
||||||
|
ops.add_to_collection("v", v2)
|
||||||
|
|
||||||
|
# Set up an assignment op to be run as part of the init op.
|
||||||
|
assign_v2 = state_ops.assign(v2, v1)
|
||||||
|
init_op = control_flow_ops.group(assign_v2, name="init_op")
|
||||||
|
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
|
||||||
|
ops.add_to_collection(key, control_flow_ops.no_op())
|
||||||
|
# ValueError should be raised since the LEGACY_INIT_OP_KEY collection
|
||||||
|
# is not empty and we don't support multiple init ops.
|
||||||
|
with self.assertRaisesRegexp(ValueError, "Graph already contains"):
|
||||||
|
builder.add_meta_graph_and_variables(
|
||||||
|
sess, ["foo"], legacy_init_op=init_op)
|
||||||
|
# We shouldn't be able to add as MAIN_OP, either.
|
||||||
|
with self.assertRaisesRegexp(ValueError, "Graph already contains"):
|
||||||
|
builder.add_meta_graph_and_variables(sess, ["foo"], main_op=init_op)
|
||||||
|
|
||||||
|
def testStripDefaultAttrs(self):
|
||||||
|
export_dir = self._get_export_dir("test_strip_default_attrs")
|
||||||
|
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
|
# Add a graph with two float32 variables and a Complex Op composing them
|
||||||
|
# with strip_default_attrs enabled.
|
||||||
|
with session.Session(graph=ops.Graph()) as sess:
|
||||||
|
real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
|
||||||
|
imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
|
||||||
|
math_ops.complex(real_num, imag_num, name="complex")
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
builder.add_meta_graph_and_variables(
|
||||||
|
sess, ["foo"], strip_default_attrs=True)
|
||||||
|
|
||||||
|
# Add a graph with the same float32 variables and a Complex Op composing
|
||||||
|
# them with strip_default_attrs disabled.
|
||||||
|
with session.Session(graph=ops.Graph()) as sess:
|
||||||
|
real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
|
||||||
|
imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
|
||||||
|
math_ops.complex(real_num, imag_num, name="complex")
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
builder.add_meta_graph(["bar"], strip_default_attrs=False)
|
||||||
|
|
||||||
|
# Save the SavedModel to disk in text format.
|
||||||
|
builder.save(as_text=True)
|
||||||
|
|
||||||
|
# Loading graph "foo" via the loader must restore the defaults for the
|
||||||
|
# "Complex" node based on the "Complex" OpDef in the Op registry.
|
||||||
|
sess = session.Session(graph=ops.Graph())
|
||||||
|
meta_graph_def = loader.load(sess, ["foo"], export_dir)
|
||||||
|
complex_node = test_util.get_node_def_from_graph("complex",
|
||||||
|
meta_graph_def.graph_def)
|
||||||
|
self.assertIn("T", complex_node.attr)
|
||||||
|
self.assertIn("Tout", complex_node.attr)
|
||||||
|
|
||||||
|
# Load graph "foo" from disk as-is to verify default attrs are stripped.
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
saved_model_pb = loader_impl._parse_saved_model(export_dir)
|
||||||
|
self.assertIsNotNone(saved_model_pb)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
meta_graph_foo_def = None
|
||||||
|
meta_graph_bar_def = None
|
||||||
|
for meta_graph_def in saved_model_pb.meta_graphs:
|
||||||
|
if set(meta_graph_def.meta_info_def.tags) == set(["foo"]):
|
||||||
|
meta_graph_foo_def = meta_graph_def
|
||||||
|
elif set(meta_graph_def.meta_info_def.tags) == set(["bar"]):
|
||||||
|
meta_graph_bar_def = meta_graph_def
|
||||||
|
|
||||||
|
self.assertIsNotNone(meta_graph_foo_def)
|
||||||
|
self.assertIsNotNone(meta_graph_bar_def)
|
||||||
|
|
||||||
|
# "Complex" Op has 2 attributes with defaults:
|
||||||
|
# o "T" : float32. (input type)
|
||||||
|
# o "Tout" : complex64. (output type)
|
||||||
|
|
||||||
|
# "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
|
||||||
|
# Graph "foo" was saved with strip_default_attrs set to True.
|
||||||
|
node_def = test_util.get_node_def_from_graph("complex",
|
||||||
|
meta_graph_foo_def.graph_def)
|
||||||
|
self.assertNotIn("T", node_def.attr)
|
||||||
|
self.assertNotIn("Tout", node_def.attr)
|
||||||
|
|
||||||
|
# "Complex" Op in graph "bar" must have attributes "T" and "Tout".
|
||||||
|
# Graph "bar" was saved with strip_default_attrs set to False.
|
||||||
|
node_def = test_util.get_node_def_from_graph("complex",
|
||||||
|
meta_graph_bar_def.graph_def)
|
||||||
|
self.assertIn("T", node_def.attr)
|
||||||
|
self.assertIn("Tout", node_def.attr)
|
||||||
|
|
||||||
|
def testLegacyInitOp(self):
|
||||||
|
export_dir = self._get_export_dir("test_legacy_init_op")
|
||||||
|
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
|
# Add `v1` and `v2` variables to the graph.
|
||||||
|
v1 = variables.VariableV1(1, name="v1")
|
||||||
|
ops.add_to_collection("v", v1)
|
||||||
|
v2 = variables.VariableV1(2, name="v2")
|
||||||
|
ops.add_to_collection("v", v2)
|
||||||
|
|
||||||
|
# Initialize another variable `v3` to 42.
|
||||||
|
v3 = variables.VariableV1(42, name="v3", trainable=False, collections=[])
|
||||||
|
ops.add_to_collection("v", v3)
|
||||||
|
|
||||||
|
# Set up an assignment op to be run as part of the init_op.
|
||||||
|
assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2))
|
||||||
|
legacy_init_op = control_flow_ops.group(assign_v3, name="legacy_init_op")
|
||||||
|
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
builder.add_meta_graph_and_variables(
|
||||||
|
sess, ["foo"], legacy_init_op=legacy_init_op)
|
||||||
|
|
||||||
|
# Save the SavedModel to disk.
|
||||||
|
builder.save()
|
||||||
|
|
||||||
|
with self.session(graph=ops.Graph()) as sess:
|
||||||
|
loader.load(sess, ["foo"], export_dir)
|
||||||
|
self.assertEqual(1, ops.get_collection("v")[0].eval())
|
||||||
|
self.assertEqual(2, ops.get_collection("v")[1].eval())
|
||||||
|
# Evaluates to the sum of the first two variables and assigned as part of
|
||||||
|
# the legacy_init_op, following a restore.
|
||||||
|
self.assertEqual(3, ops.get_collection("v")[2].eval())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -24,6 +24,8 @@ from __future__ import print_function
|
|||||||
from tensorflow.python.saved_model.signature_def_utils_impl import build_signature_def
|
from tensorflow.python.saved_model.signature_def_utils_impl import build_signature_def
|
||||||
from tensorflow.python.saved_model.signature_def_utils_impl import classification_signature_def
|
from tensorflow.python.saved_model.signature_def_utils_impl import classification_signature_def
|
||||||
from tensorflow.python.saved_model.signature_def_utils_impl import is_valid_signature
|
from tensorflow.python.saved_model.signature_def_utils_impl import is_valid_signature
|
||||||
|
from tensorflow.python.saved_model.signature_def_utils_impl import load_op_from_signature_def
|
||||||
|
from tensorflow.python.saved_model.signature_def_utils_impl import op_signature_def
|
||||||
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
|
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
|
||||||
from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def
|
from tensorflow.python.saved_model.signature_def_utils_impl import regression_signature_def
|
||||||
from tensorflow.python.saved_model.signature_def_utils_impl import supervised_eval_signature_def
|
from tensorflow.python.saved_model.signature_def_utils_impl import supervised_eval_signature_def
|
||||||
|
@ -21,9 +21,10 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.core.framework import types_pb2
|
from tensorflow.core.framework import types_pb2
|
||||||
from tensorflow.core.protobuf import meta_graph_pb2
|
from tensorflow.core.protobuf import meta_graph_pb2
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.saved_model import signature_constants
|
from tensorflow.python.saved_model import signature_constants
|
||||||
from tensorflow.python.saved_model import utils
|
from tensorflow.python.saved_model import utils_impl as utils
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
@ -349,3 +350,51 @@ def _is_valid_classification_signature(signature_def):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def op_signature_def(op, key):
|
||||||
|
"""Creates a signature def with the output pointing to an op.
|
||||||
|
|
||||||
|
Note that op isn't strictly enforced to be an Op object, and may be a Tensor.
|
||||||
|
It is recommended to use the build_signature_def() function for Tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
op: An Op (or possibly Tensor).
|
||||||
|
key: Key to graph element in the SignatureDef outputs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A SignatureDef with a single output pointing to the op.
|
||||||
|
"""
|
||||||
|
# Use build_tensor_info_from_op, which creates a TensorInfo from the element's
|
||||||
|
# name.
|
||||||
|
return build_signature_def(outputs={key: utils.build_tensor_info_from_op(op)})
|
||||||
|
|
||||||
|
|
||||||
|
def load_op_from_signature_def(signature_def, key, import_scope=None):
|
||||||
|
"""Load an Op from a SignatureDef created by op_signature_def().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
signature_def: a SignatureDef proto
|
||||||
|
key: string key to op in the SignatureDef outputs.
|
||||||
|
import_scope: Scope used to import the op
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Op (or possibly Tensor) in the graph with the same name as saved in the
|
||||||
|
SignatureDef.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError: If the op could not be found in the graph.
|
||||||
|
"""
|
||||||
|
tensor_info = signature_def.outputs[key]
|
||||||
|
try:
|
||||||
|
# The init and train ops are not strictly enforced to be operations, so
|
||||||
|
# retrieve any graph element (can be either op or tensor).
|
||||||
|
return utils.get_element_from_tensor_info(
|
||||||
|
tensor_info, import_scope=import_scope)
|
||||||
|
except KeyError:
|
||||||
|
raise errors.NotFoundError(
|
||||||
|
None, None,
|
||||||
|
'The {0} could not be found in the graph. Please make sure the '
|
||||||
|
'SavedModel was created by the internal _SavedModelBuilder. If you '
|
||||||
|
'are using the public API, please make sure the SignatureDef in the '
|
||||||
|
'SavedModel does not contain the key "{0}".'.format(key))
|
||||||
|
@ -23,6 +23,7 @@ from tensorflow.core.protobuf import meta_graph_pb2
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.saved_model import signature_constants
|
from tensorflow.python.saved_model import signature_constants
|
||||||
from tensorflow.python.saved_model import signature_def_utils_impl
|
from tensorflow.python.saved_model import signature_def_utils_impl
|
||||||
@ -413,5 +414,22 @@ class SignatureDefUtilsTest(test.TestCase):
|
|||||||
{},
|
{},
|
||||||
signature_constants.PREDICT_METHOD_NAME)
|
signature_constants.PREDICT_METHOD_NAME)
|
||||||
|
|
||||||
|
def testOpSignatureDef(self):
|
||||||
|
key = "adding_1_and_2_key"
|
||||||
|
add_op = math_ops.add(1, 2, name="adding_1_and_2")
|
||||||
|
signature_def = signature_def_utils_impl.op_signature_def(add_op, key)
|
||||||
|
self.assertIn(key, signature_def.outputs)
|
||||||
|
self.assertEqual(add_op.name, signature_def.outputs[key].name)
|
||||||
|
|
||||||
|
def testLoadOpFromSignatureDef(self):
|
||||||
|
key = "adding_1_and_2_key"
|
||||||
|
add_op = math_ops.add(1, 2, name="adding_1_and_2")
|
||||||
|
signature_def = signature_def_utils_impl.op_signature_def(add_op, key)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
add_op,
|
||||||
|
signature_def_utils_impl.load_op_from_signature_def(signature_def, key))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -141,6 +141,27 @@ def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
|
|||||||
raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
|
raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
|
||||||
|
|
||||||
|
|
||||||
|
def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
|
||||||
|
"""Returns the element in the graph described by a TensorInfo proto.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor_info: A TensorInfo proto describing an Op or Tensor by name.
|
||||||
|
graph: The tf.Graph in which tensors are looked up. If None, the current
|
||||||
|
default graph is used.
|
||||||
|
import_scope: If not None, names in `tensor_info` are prefixed with this
|
||||||
|
string before lookup.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Op or tensor in `graph` described by `tensor_info`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: If `tensor_info` does not correspond to an op or tensor in `graph`
|
||||||
|
"""
|
||||||
|
graph = graph or ops.get_default_graph()
|
||||||
|
return graph.as_graph_element(
|
||||||
|
ops.prepend_name_scope(tensor_info.name, import_scope=import_scope))
|
||||||
|
|
||||||
|
|
||||||
# Path helpers.
|
# Path helpers.
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,10 +32,6 @@ tf_module {
|
|||||||
name: "GPU"
|
name: "GPU"
|
||||||
mtype: "<type \'str\'>"
|
mtype: "<type \'str\'>"
|
||||||
}
|
}
|
||||||
member {
|
|
||||||
name: "MAIN_OP_KEY"
|
|
||||||
mtype: "<type \'str\'>"
|
|
||||||
}
|
|
||||||
member {
|
member {
|
||||||
name: "PREDICT_INPUTS"
|
name: "PREDICT_INPUTS"
|
||||||
mtype: "<type \'str\'>"
|
mtype: "<type \'str\'>"
|
||||||
|
Loading…
Reference in New Issue
Block a user