Return model format from LoadSessionBundleOrSavedModelBundle(),
allowing callers to know if we up-converted a SessionBundle to SavedModel format. PiperOrigin-RevId: 213937542
This commit is contained in:
parent
f10b00558d
commit
23552a8b2f
@ -355,11 +355,15 @@ Status LoadSessionBundleOrSavedModelBundle(
|
||||
const SessionOptions& session_options, const RunOptions& run_options,
|
||||
const string& export_dir,
|
||||
const std::unordered_set<string>& saved_model_tags,
|
||||
SavedModelBundle* saved_model_bundle) {
|
||||
SavedModelBundle* saved_model_bundle, bool* is_session_bundle) {
|
||||
if (is_session_bundle != nullptr) {
|
||||
*is_session_bundle = false;
|
||||
}
|
||||
if (MaybeSavedModelDirectory(export_dir)) {
|
||||
LOG(INFO)
|
||||
<< "Attempting to load native SavedModelBundle in bundle-shim from: "
|
||||
<< export_dir;
|
||||
|
||||
return LoadSavedModel(session_options, run_options, export_dir,
|
||||
saved_model_tags, saved_model_bundle);
|
||||
} else if (IsPossibleExportDirectory(export_dir)) {
|
||||
@ -368,6 +372,9 @@ Status LoadSessionBundleOrSavedModelBundle(
|
||||
LOG(INFO) << "Attempting to up-convert SessionBundle to SavedModelBundle "
|
||||
"in bundle-shim from: "
|
||||
<< export_dir;
|
||||
if (is_session_bundle != nullptr) {
|
||||
*is_session_bundle = true;
|
||||
}
|
||||
return LoadSavedModelFromLegacySessionBundlePath(
|
||||
session_options, run_options, export_dir, saved_model_bundle);
|
||||
}
|
||||
|
@ -59,11 +59,13 @@ Status ConvertSessionBundleToSavedModelBundle(
|
||||
} // namespace internal
|
||||
|
||||
// Loads a SavedModel from either a session-bundle path or a SavedModel bundle
|
||||
// path.
|
||||
// path. If `is_session_bundle` is not a nullptr, sets it to `true` iff
|
||||
// SavedModel was up-converted and loaded from a SessionBundle.
|
||||
// `is_session_bundle` value should not be used if error is returned.
|
||||
Status LoadSessionBundleOrSavedModelBundle(
|
||||
const SessionOptions& session_options, const RunOptions& run_options,
|
||||
const string& export_dir, const std::unordered_set<string>& tags,
|
||||
SavedModelBundle* bundle);
|
||||
SavedModelBundle* bundle, bool* is_session_bundle = nullptr);
|
||||
|
||||
} // namespace serving
|
||||
} // namespace tensorflow
|
||||
|
@ -63,12 +63,16 @@ void ValidateHalfPlusTwo(const SavedModelBundle& saved_model_bundle,
|
||||
|
||||
void LoadAndValidateSavedModelBundle(const string& export_dir,
|
||||
const std::unordered_set<string>& tags,
|
||||
const string& signature_def_key) {
|
||||
const string& signature_def_key,
|
||||
bool expect_session_bundle) {
|
||||
SessionOptions session_options;
|
||||
RunOptions run_options;
|
||||
SavedModelBundle saved_model_bundle;
|
||||
bool is_session_bundle = false;
|
||||
TF_ASSERT_OK(LoadSessionBundleOrSavedModelBundle(
|
||||
session_options, run_options, export_dir, tags, &saved_model_bundle));
|
||||
session_options, run_options, export_dir, tags, &saved_model_bundle,
|
||||
&is_session_bundle));
|
||||
EXPECT_EQ(expect_session_bundle, is_session_bundle);
|
||||
const MetaGraphDef meta_graph_def = saved_model_bundle.meta_graph_def;
|
||||
const auto& signature_def_map = meta_graph_def.signature_def();
|
||||
|
||||
@ -512,7 +516,8 @@ TEST(BundleShimTest, BasicExportSessionBundle) {
|
||||
const string session_bundle_export_dir =
|
||||
test_util::TestSrcDirPath(kSessionBundlePath);
|
||||
LoadAndValidateSavedModelBundle(session_bundle_export_dir, tags,
|
||||
kDefaultServingSignatureDefKey);
|
||||
kDefaultServingSignatureDefKey,
|
||||
/*expect_session_bundle=*/true);
|
||||
|
||||
// Verify that the named signature is also present.
|
||||
SessionOptions session_options;
|
||||
@ -558,7 +563,8 @@ TEST(BundleShimTest, BasicExportSavedModel) {
|
||||
const string saved_model_bundle_export_dir =
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kSavedModelBundlePath);
|
||||
LoadAndValidateSavedModelBundle(saved_model_bundle_export_dir,
|
||||
{kSavedModelTagServe}, "regress_x_to_y");
|
||||
{kSavedModelTagServe}, "regress_x_to_y",
|
||||
/*expect_session_bundle=*/false);
|
||||
}
|
||||
|
||||
// Checks a basic load fails with an invalid export path.
|
||||
|
Loading…
Reference in New Issue
Block a user