Return model format from LoadSessionBundleOrSavedModelBundle(),

allowing callers to know if we up-converted a SessionBundle to
SavedModel format.

PiperOrigin-RevId: 213937542
This commit is contained in:
Abhijit Karmarkar 2018-09-20 22:18:35 -07:00 committed by TensorFlower Gardener
parent f10b00558d
commit 23552a8b2f
3 changed files with 22 additions and 7 deletions

View File

@ -355,11 +355,15 @@ Status LoadSessionBundleOrSavedModelBundle(
const SessionOptions& session_options, const RunOptions& run_options,
const string& export_dir,
const std::unordered_set<string>& saved_model_tags,
SavedModelBundle* saved_model_bundle) {
SavedModelBundle* saved_model_bundle, bool* is_session_bundle) {
if (is_session_bundle != nullptr) {
*is_session_bundle = false;
}
if (MaybeSavedModelDirectory(export_dir)) {
LOG(INFO)
<< "Attempting to load native SavedModelBundle in bundle-shim from: "
<< export_dir;
return LoadSavedModel(session_options, run_options, export_dir,
saved_model_tags, saved_model_bundle);
} else if (IsPossibleExportDirectory(export_dir)) {
@ -368,6 +372,9 @@ Status LoadSessionBundleOrSavedModelBundle(
LOG(INFO) << "Attempting to up-convert SessionBundle to SavedModelBundle "
"in bundle-shim from: "
<< export_dir;
if (is_session_bundle != nullptr) {
*is_session_bundle = true;
}
return LoadSavedModelFromLegacySessionBundlePath(
session_options, run_options, export_dir, saved_model_bundle);
}

View File

@ -59,11 +59,13 @@ Status ConvertSessionBundleToSavedModelBundle(
} // namespace internal
// Loads a SavedModel from either a session-bundle path or a SavedModel bundle
// path.
// path. If `is_session_bundle` is not a nullptr, sets it to `true` iff
// SavedModel was up-converted and loaded from a SessionBundle.
// `is_session_bundle` value should not be used if error is returned.
Status LoadSessionBundleOrSavedModelBundle(
const SessionOptions& session_options, const RunOptions& run_options,
const string& export_dir, const std::unordered_set<string>& tags,
SavedModelBundle* bundle);
SavedModelBundle* bundle, bool* is_session_bundle = nullptr);
} // namespace serving
} // namespace tensorflow

View File

@ -63,12 +63,16 @@ void ValidateHalfPlusTwo(const SavedModelBundle& saved_model_bundle,
void LoadAndValidateSavedModelBundle(const string& export_dir,
const std::unordered_set<string>& tags,
const string& signature_def_key) {
const string& signature_def_key,
bool expect_session_bundle) {
SessionOptions session_options;
RunOptions run_options;
SavedModelBundle saved_model_bundle;
bool is_session_bundle = false;
TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(
session_options, run_options, export_dir, tags, &saved_model_bundle));
session_options, run_options, export_dir, tags, &saved_model_bundle,
&is_session_bundle));
EXPECT_EQ(expect_session_bundle, is_session_bundle);
const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
const auto& signature_def_map = meta_graph_def.signature_def();
@ -512,7 +516,8 @@ TEST(BundleShimTest, BasicExportSessionBundle) {
const string session_bundle_export_dir =
test_util::TestSrcDirPath(kSessionBundlePath);
LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags,
kDefaultServingSignatureDefKey);
kDefaultServingSignatureDefKey,
/*expect_session_bundle=*/true);
// Verify that the named signature is also present.
SessionOptions session_options;
@ -558,7 +563,8 @@ TEST(BundleShimTest, BasicExportSavedModel) {
const string saved_model_bundle_export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kSavedModelBundlePath);
LoadAndValidateSavedModelBundle(saved_model_bundle_export_dir,
{kSavedModelTagServe}, "regress_x_to_y");
{kSavedModelTagServe}, "regress_x_to_y",
/*expect_session_bundle=*/false);
}
// Checks a basic load fails with an invalid export path.