TF SavedModel: Split off a reader from the loader module
PiperOrigin-RevId: 204468340
This commit is contained in:
parent
895a766788
commit
98010279f4
@ -33,6 +33,35 @@ cc_library(
|
|||||||
hdrs = ["tag_constants.h"],
|
hdrs = ["tag_constants.h"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "reader",
|
||||||
|
srcs = ["reader.cc"],
|
||||||
|
hdrs = ["reader.h"],
|
||||||
|
deps = [
|
||||||
|
":constants",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "reader_test",
|
||||||
|
srcs = ["reader_test.cc"],
|
||||||
|
data = [
|
||||||
|
":saved_model_half_plus_two",
|
||||||
|
],
|
||||||
|
linkstatic = 1,
|
||||||
|
deps = [
|
||||||
|
":constants",
|
||||||
|
":reader",
|
||||||
|
":tag_constants",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "loader",
|
name = "loader",
|
||||||
hdrs = ["loader.h"],
|
hdrs = ["loader.h"],
|
||||||
@ -54,6 +83,7 @@ cc_library(
|
|||||||
hdrs = ["loader.h"],
|
hdrs = ["loader.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":constants",
|
":constants",
|
||||||
|
":reader",
|
||||||
] + if_not_mobile([
|
] + if_not_mobile([
|
||||||
"//tensorflow/core:core_cpu",
|
"//tensorflow/core:core_cpu",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
|
@ -18,8 +18,10 @@ limitations under the License.
|
|||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
|
|
||||||
#include "tensorflow/cc/saved_model/constants.h"
|
#include "tensorflow/cc/saved_model/constants.h"
|
||||||
|
#include "tensorflow/cc/saved_model/reader.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/protobuf_internal.h"
|
#include "tensorflow/core/platform/protobuf_internal.h"
|
||||||
@ -43,56 +45,6 @@ auto* load_latency = monitoring::Counter<1>::New(
|
|||||||
constexpr char kLoadAttemptFail[] = "fail";
|
constexpr char kLoadAttemptFail[] = "fail";
|
||||||
constexpr char kLoadAttemptSuccess[] = "success";
|
constexpr char kLoadAttemptSuccess[] = "success";
|
||||||
|
|
||||||
Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
|
|
||||||
const string saved_model_pb_path =
|
|
||||||
io::JoinPath(export_dir, kSavedModelFilenamePb);
|
|
||||||
if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
|
|
||||||
return ReadBinaryProto(Env::Default(), saved_model_pb_path,
|
|
||||||
saved_model_proto);
|
|
||||||
}
|
|
||||||
const string saved_model_pbtxt_path =
|
|
||||||
io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
|
|
||||||
if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
|
|
||||||
return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
|
|
||||||
saved_model_proto);
|
|
||||||
}
|
|
||||||
return Status(error::Code::NOT_FOUND,
|
|
||||||
"Could not find SavedModel .pb or .pbtxt at supplied export "
|
|
||||||
"directory path: " +
|
|
||||||
export_dir);
|
|
||||||
}
|
|
||||||
|
|
||||||
string GetTagsAsString(const std::unordered_set<string>& tags) {
|
|
||||||
string tags_as_string = "{ ";
|
|
||||||
for (const string& tag : tags) {
|
|
||||||
tags_as_string = strings::StrCat(tags_as_string, tag, " ");
|
|
||||||
}
|
|
||||||
tags_as_string = strings::StrCat(tags_as_string, "}");
|
|
||||||
return tags_as_string;
|
|
||||||
}
|
|
||||||
|
|
||||||
Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto,
|
|
||||||
const std::unordered_set<string>& tags,
|
|
||||||
MetaGraphDef* meta_graph_def_to_load) {
|
|
||||||
for (const MetaGraphDef& meta_graph_def : saved_model_proto.meta_graphs()) {
|
|
||||||
// Get tags from the meta_graph_def.
|
|
||||||
std::unordered_set<string> graph_tags;
|
|
||||||
for (const string& tag : meta_graph_def.meta_info_def().tags()) {
|
|
||||||
graph_tags.insert(tag);
|
|
||||||
}
|
|
||||||
// Match with the set of tags provided.
|
|
||||||
if (graph_tags == tags) {
|
|
||||||
*meta_graph_def_to_load = meta_graph_def;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Status(error::Code::NOT_FOUND,
|
|
||||||
"Could not find meta graph def matching supplied tags: " +
|
|
||||||
GetTagsAsString(tags) +
|
|
||||||
". To inspect available tag-sets in the SavedModel, please "
|
|
||||||
"use the SavedModel CLI: `saved_model_cli`");
|
|
||||||
}
|
|
||||||
|
|
||||||
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
|
Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def,
|
||||||
const SessionOptions& session_options,
|
const SessionOptions& session_options,
|
||||||
std::unique_ptr<Session>* session) {
|
std::unique_ptr<Session>* session) {
|
||||||
@ -235,18 +187,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
|
|||||||
const string& export_dir,
|
const string& export_dir,
|
||||||
const std::unordered_set<string>& tags,
|
const std::unordered_set<string>& tags,
|
||||||
SavedModelBundle* const bundle) {
|
SavedModelBundle* const bundle) {
|
||||||
if (!MaybeSavedModelDirectory(export_dir)) {
|
TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags,
|
||||||
return Status(error::Code::NOT_FOUND,
|
&bundle->meta_graph_def));
|
||||||
"SavedModel not found in export directory: " + export_dir);
|
|
||||||
}
|
|
||||||
LOG(INFO) << "Loading SavedModel with tags: " << GetTagsAsString(tags)
|
|
||||||
<< "; from: " << export_dir;
|
|
||||||
|
|
||||||
SavedModel saved_model_proto;
|
|
||||||
TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
FindMetaGraphDefToLoad(saved_model_proto, tags, &bundle->meta_graph_def));
|
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
|
TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession(
|
||||||
bundle->meta_graph_def, session_options, &bundle->session));
|
bundle->meta_graph_def, session_options, &bundle->session));
|
||||||
@ -288,8 +230,8 @@ Status LoadSavedModel(const SessionOptions& session_options,
|
|||||||
return end_microseconds - start_microseconds;
|
return end_microseconds - start_microseconds;
|
||||||
}();
|
}();
|
||||||
auto log_and_count = [&](const string& status_str) {
|
auto log_and_count = [&](const string& status_str) {
|
||||||
LOG(INFO) << "SavedModel load for tags " << GetTagsAsString(tags)
|
LOG(INFO) << "SavedModel load for tags { " << str_util::Join(tags, " ")
|
||||||
<< "; Status: " << status_str << ". Took "
|
<< " }; Status: " << status_str << ". Took "
|
||||||
<< load_latency_microsecs << " microseconds.";
|
<< load_latency_microsecs << " microseconds.";
|
||||||
load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
|
load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
|
||||||
};
|
};
|
||||||
|
88
tensorflow/cc/saved_model/reader.cc
Normal file
88
tensorflow/cc/saved_model/reader.cc
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/saved_model/reader.h"
|
||||||
|
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
#include "tensorflow/cc/saved_model/constants.h"
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
|
#include "tensorflow/core/protobuf/saved_model.pb.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) {
|
||||||
|
LOG(INFO) << "Reading SavedModel from: " << export_dir;
|
||||||
|
|
||||||
|
const string saved_model_pb_path =
|
||||||
|
io::JoinPath(export_dir, kSavedModelFilenamePb);
|
||||||
|
if (Env::Default()->FileExists(saved_model_pb_path).ok()) {
|
||||||
|
return ReadBinaryProto(Env::Default(), saved_model_pb_path,
|
||||||
|
saved_model_proto);
|
||||||
|
}
|
||||||
|
const string saved_model_pbtxt_path =
|
||||||
|
io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
|
||||||
|
if (Env::Default()->FileExists(saved_model_pbtxt_path).ok()) {
|
||||||
|
return ReadTextProto(Env::Default(), saved_model_pbtxt_path,
|
||||||
|
saved_model_proto);
|
||||||
|
}
|
||||||
|
return Status(error::Code::NOT_FOUND,
|
||||||
|
"Could not find SavedModel .pb or .pbtxt at supplied export "
|
||||||
|
"directory path: " +
|
||||||
|
export_dir);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status FindMetaGraphDef(const SavedModel& saved_model_proto,
|
||||||
|
const std::unordered_set<string>& tags,
|
||||||
|
MetaGraphDef* meta_graph_def) {
|
||||||
|
LOG(INFO) << "Reading meta graph with tags { " << str_util::Join(tags, " ")
|
||||||
|
<< " }";
|
||||||
|
for (const MetaGraphDef& graph_def : saved_model_proto.meta_graphs()) {
|
||||||
|
// Get tags from the graph_def.
|
||||||
|
std::unordered_set<string> graph_tags;
|
||||||
|
for (const string& tag : graph_def.meta_info_def().tags()) {
|
||||||
|
graph_tags.insert(tag);
|
||||||
|
}
|
||||||
|
// Match with the set of tags provided.
|
||||||
|
if (graph_tags == tags) {
|
||||||
|
*meta_graph_def = graph_def;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status(
|
||||||
|
error::Code::NOT_FOUND,
|
||||||
|
strings::StrCat(
|
||||||
|
"Could not find meta graph def matching supplied tags: { ",
|
||||||
|
str_util::Join(tags, " "),
|
||||||
|
" }. To inspect available tag-sets in the SavedModel, please "
|
||||||
|
"use the SavedModel CLI: `saved_model_cli`"));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
|
||||||
|
const std::unordered_set<string>& tags,
|
||||||
|
MetaGraphDef* const meta_graph_def) {
|
||||||
|
SavedModel saved_model_proto;
|
||||||
|
TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto));
|
||||||
|
TF_RETURN_IF_ERROR(FindMetaGraphDef(saved_model_proto, tags, meta_graph_def));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
39
tensorflow/cc/saved_model/reader.h
Normal file
39
tensorflow/cc/saved_model/reader.h
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
/// Functions to read the SavedModel proto, or parts of it.
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CC_SAVED_MODEL_READER_H_
|
||||||
|
#define TENSORFLOW_CC_SAVED_MODEL_READER_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Reads the SavedModel proto from saved_model.pb(txt) in the given directory,
|
||||||
|
// finds the MetaGraphDef that matches the given set of tags and writes it to
|
||||||
|
// the `meta_graph_def` parameter. Returns a failure status when the SavedModel
|
||||||
|
// file does not exist or no MetaGraphDef matches the tags.
|
||||||
|
Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
|
||||||
|
const std::unordered_set<string>& tags,
|
||||||
|
MetaGraphDef* const meta_graph_def);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CC_SAVED_MODEL_READER_H_
|
108
tensorflow/cc/saved_model/reader_test.cc
Normal file
108
tensorflow/cc/saved_model/reader_test.cc
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/cc/saved_model/reader.h"
|
||||||
|
|
||||||
|
#include "tensorflow/cc/saved_model/constants.h"
|
||||||
|
#include "tensorflow/cc/saved_model/tag_constants.h"
|
||||||
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
constexpr char kTestDataPbTxt[] =
|
||||||
|
"cc/saved_model/testdata/half_plus_two_pbtxt/00000123";
|
||||||
|
constexpr char kTestDataSharded[] =
|
||||||
|
"cc/saved_model/testdata/half_plus_two/00000123";
|
||||||
|
|
||||||
|
class ReaderTest : public ::testing::Test {
|
||||||
|
protected:
|
||||||
|
ReaderTest() {}
|
||||||
|
|
||||||
|
void CheckMetaGraphDef(const MetaGraphDef& meta_graph_def) {
|
||||||
|
const auto& tags = meta_graph_def.meta_info_def().tags();
|
||||||
|
EXPECT_TRUE(std::find(tags.begin(), tags.end(), kSavedModelTagServe) !=
|
||||||
|
tags.end());
|
||||||
|
EXPECT_NE(meta_graph_def.meta_info_def().tensorflow_version(), "");
|
||||||
|
EXPECT_EQ(
|
||||||
|
meta_graph_def.signature_def().at("serving_default").method_name(),
|
||||||
|
"tensorflow/serving/predict");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(ReaderTest, TagMatch) {
|
||||||
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
|
const string export_dir =
|
||||||
|
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
||||||
|
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||||
|
&meta_graph_def));
|
||||||
|
CheckMetaGraphDef(meta_graph_def);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ReaderTest, NoTagMatch) {
|
||||||
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
|
const string export_dir =
|
||||||
|
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
||||||
|
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {"missing-tag"},
|
||||||
|
&meta_graph_def);
|
||||||
|
EXPECT_FALSE(st.ok());
|
||||||
|
EXPECT_TRUE(str_util::StrContains(
|
||||||
|
st.error_message(),
|
||||||
|
"Could not find meta graph def matching supplied tags: { missing-tag }"))
|
||||||
|
<< st.error_message();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ReaderTest, NoTagMatchMultiple) {
|
||||||
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
|
const string export_dir =
|
||||||
|
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded);
|
||||||
|
Status st = ReadMetaGraphDefFromSavedModel(
|
||||||
|
export_dir, {kSavedModelTagServe, "missing-tag"}, &meta_graph_def);
|
||||||
|
EXPECT_FALSE(st.ok());
|
||||||
|
EXPECT_TRUE(str_util::StrContains(
|
||||||
|
st.error_message(),
|
||||||
|
"Could not find meta graph def matching supplied tags: "))
|
||||||
|
<< st.error_message();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ReaderTest, PbtxtFormat) {
|
||||||
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
|
const string export_dir =
|
||||||
|
io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPbTxt);
|
||||||
|
TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||||
|
&meta_graph_def));
|
||||||
|
CheckMetaGraphDef(meta_graph_def);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ReaderTest, InvalidExportPath) {
|
||||||
|
MetaGraphDef meta_graph_def;
|
||||||
|
|
||||||
|
const string export_dir =
|
||||||
|
io::JoinPath(testing::TensorFlowSrcRoot(), "missing-path");
|
||||||
|
Status st = ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe},
|
||||||
|
&meta_graph_def);
|
||||||
|
EXPECT_FALSE(st.ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
@ -47,7 +47,7 @@ public class SavedModelBundleTest {
|
|||||||
fail("not expected");
|
fail("not expected");
|
||||||
} catch (org.tensorflow.TensorFlowException e) {
|
} catch (org.tensorflow.TensorFlowException e) {
|
||||||
// expected exception
|
// expected exception
|
||||||
assertTrue(e.getMessage().contains("SavedModel not found"));
|
assertTrue(e.getMessage().contains("Could not find SavedModel"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user