TF SavedModel: Split off a reader from the loader module

PiperOrigin-RevId: 204468340
This commit is contained in:
A. Unique TensorFlower 2018-07-13 07:21:37 -07:00 committed by TensorFlower Gardener
parent 895a766788
commit 98010279f4
6 changed files with 272 additions and 65 deletions

View File

@ -33,6 +33,35 @@ cc_library(
hdrs = ["tag_constants.h"],
)
cc_library(
name = "reader",
srcs = ["reader.cc"],
hdrs = ["reader.h"],
deps = [
":constants",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
)
tf_cc_test(
name = "reader_test",
srcs = ["reader_test.cc"],
data = [
":saved_model_half_plus_two",
],
linkstatic = 1,
deps = [
":constants",
":reader",
":tag_constants",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
cc_library(
name = "loader",
hdrs = ["loader.h"],
@ -54,6 +83,7 @@ cc_library(
hdrs = ["loader.h"],
deps = [
":constants",
":reader",
] + if_not_mobile([
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",

View File

@ -18,8 +18,10 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf_internal.h"
@ -43,56 +45,6 @@ auto* load_latency = monitoring::Counter<1>::New(
constexpr char kLoadAttemptFail[] = "fail";
constexpr char kLoadAttemptSuccess[] = "success";
Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
const string saved_model_pb_path =
io::JoinPath(export_dir, kSavedModelFilenamePb);
if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
return ReadBinaryProto(Env::Default(), saved_model_pb_path,
saved_model_proto);
}
const string saved_model_pbtxt_path =
io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
saved_model_proto);
}
return Status(error::Code::NOT_FOUND,
"Could not find SavedModel .pb or .pbtxt at supplied export "
"directory path: " +
export_dir);
}
string GetTagsAsString(const std::unordered_set<string>& tags) {
string tags_as_string = "{ ";
for (const string& tag : tags) {
tags_as_string = strings::StrCat(tags_as_string, tag, " ");
}
tags_as_string = strings::StrCat(tags_as_string, "}");
return tags_as_string;
}
Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto,
const std::unordered_set<string>& tags,
MetaGraphDef* meta_graph_def_to_load) {
for (const MetaGraphDef& meta_graph_def : saved_model_proto.meta_graphs()) {
// Get tags from the meta_graph_def.
std::unordered_set<string> graph_tags;
for (const string& tag : meta_graph_def.meta_info_def().tags()) {
graph_tags.insert(tag);
}
// Match with the set of tags provided.
if (graph_tags == tags) {
*meta_graph_def_to_load = meta_graph_def;
return Status::OK();
}
}
return Status(error::Code::NOT_FOUND,
"Could not find meta graph def matching supplied tags: " +
GetTagsAsString(tags) +
". To inspect available tag-sets in the SavedModel, please "
"use the SavedModel CLI: `saved_model_cli`");
}
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
const SessionOptions& session_options,
std::unique_ptr<Session>* session) {
@ -235,18 +187,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
const string& export_dir,
const std::unordered_set<string>& tags,
SavedModelBundle* const bundle) {
if (!MaybeSavedModelDirectory(export_dir)) {
return Status(error::Code::NOT_FOUND,
"SavedModel not found in export directory: " + export_dir);
}
LOG(INFO) << "Loading SavedModel with tags: " << GetTagsAsString(tags)
<< "; from: " << export_dir;
SavedModel saved_model_proto;
TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
TF_RETURN_IF_ERROR(
FindMetaGraphDefToLoad(saved_model_proto, tags, &bundle->meta_graph_def));
TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags,
&bundle->meta_graph_def));
TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
bundle->meta_graph_def, session_options, &bundle->session));
@ -288,8 +230,8 @@ Status LoadSavedModel(const SessionOptions& session_options,
return end_microseconds - start_microseconds;
}();
auto log_and_count = [&](const string& status_str) {
LOG(INFO) << "SavedModel load for tags " << GetTagsAsString(tags)
<< "; Status: " << status_str << ". Took "
LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ")
<< " }; Status: " << status_str << ". Took "
<< load_latency_microsecs << " microseconds.";
load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
};

View File

@ -0,0 +1,88 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/reader.h"
#include <unordered_set>
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
namespace tensorflow {
namespace {
Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
LOG(INFO) << "Reading SavedModel from: " << export_dir;
const string saved_model_pb_path =
io::JoinPath(export_dir, kSavedModelFilenamePb);
if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
return ReadBinaryProto(Env::Default(), saved_model_pb_path,
saved_model_proto);
}
const string saved_model_pbtxt_path =
io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
saved_model_proto);
}
return Status(error::Code::NOT_FOUND,
"Could not find SavedModel .pb or .pbtxt at supplied export "
"directory path: " +
export_dir);
}
Status FindMetaGraphDef(const SavedModel& saved_model_proto,
const std::unordered_set<string>& tags,
MetaGraphDef* meta_graph_def) {
LOG(INFO) << "Reading meta graph with tags { " << str_util::Join(tags, " ")
<< " }";
for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) {
// Get tags from the graph_def.
std::unordered_set<string> graph_tags;
for (const string& tag : graph_def.meta_info_def().tags()) {
graph_tags.insert(tag);
}
// Match with the set of tags provided.
if (graph_tags == tags) {
*meta_graph_def = graph_def;
return Status::OK();
}
}
return Status(
error::Code::NOT_FOUND,
strings::StrCat(
"Could not find meta graph def matching supplied tags: { ",
str_util::Join(tags, " "),
" }. To inspect available tag-sets in the SavedModel, please "
"use the SavedModel CLI: `saved_model_cli`"));
}
} // namespace
Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
const std::unordered_set<string>& tags,
MetaGraphDef* const meta_graph_def) {
SavedModel saved_model_proto;
TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
TF_RETURN_IF_ERROR(FindMetaGraphDef(saved_model_proto, tags, meta_graph_def));
return Status::OK();
}
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/// Functions to read the SavedModel proto, or parts of it.
#ifndef TENSORFLOW_CC_SAVED_MODEL_READER_H_
#define TENSORFLOW_CC_SAVED_MODEL_READER_H_
#include <string>
#include <unordered_set>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
namespace tensorflow {
// Reads the SavedModel proto from saved_model.pb(txt) in the given directory,
// finds the MetaGraphDef that matches the given set of tags and writes it to
// the `meta_graph_def` parameter. Returns a failure status when the SavedModel
// file does not exist or no MetaGraphDef matches the tags.
Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
const std::unordered_set<string>& tags,
MetaGraphDef* const meta_graph_def);
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_READER_H_

View File

@ -0,0 +1,108 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/tag_constants.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
constexpr char kTestDataPbTxt[] =
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
constexpr char kTestDataSharded[] =
"cc/saved_model/testdata/half_plus_two/00000123";
class ReaderTest : public ::testing::Test {
protected:
ReaderTest() {}
void CheckMetaGraphDef(const MetaGraphDef& meta_graph_def) {
const auto& tags = meta_graph_def.meta_info_def().tags();
EXPECT_TRUE(std::find(tags.begin(), tags.end(), kSavedModelTagServe) !=
tags.end());
EXPECT_NE(meta_graph_def.meta_info_def().tensorflow_version(), "");
EXPECT_EQ(
meta_graph_def.signature_def().at("serving_default").method_name(),
"tensorflow/serving/predict");
}
};
TEST_F(ReaderTest, TagMatch) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def));
CheckMetaGraphDef(meta_graph_def);
}
TEST_F(ReaderTest, NoTagMatch) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
&meta_graph_def);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(str_util::StrContains(
st.error_message(),
"Could not find meta graph def matching supplied tags: { missing-tag }"))
<< st.error_message();
}
TEST_F(ReaderTest, NoTagMatchMultiple) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
Status st = ReadMetaGraphDefFromSavedModel(
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
EXPECT_FALSE(st.ok());
EXPECT_TRUE(str_util::StrContains(
st.error_message(),
"Could not find meta graph def matching supplied tags: "))
<< st.error_message();
}
TEST_F(ReaderTest, PbtxtFormat) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def));
CheckMetaGraphDef(meta_graph_def);
}
TEST_F(ReaderTest, InvalidExportPath) {
MetaGraphDef meta_graph_def;
const string export_dir =
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
&meta_graph_def);
EXPECT_FALSE(st.ok());
}
} // namespace
} // namespace tensorflow

View File

@ -47,7 +47,7 @@ public class SavedModelBundleTest {
fail("not expected");
} catch (org.tensorflow.TensorFlowException e) {
// expected exception
assertTrue(e.getMessage().contains("SavedModel not found"));
assertTrue(e.getMessage().contains("Could not find SavedModel"));
}
}
}