Adds a simple util to build a GrapplerItem from a MetaGraphDef stored in a file.

PiperOrigin-RevId: 216622520
This commit is contained in:
A. Unique TensorFlower 2018-10-10 17:30:27 -07:00 committed by TensorFlower Gardener
parent 331683cb22
commit 2b010f2e48
5 changed files with 84 additions and 6 deletions

View File

@ -630,5 +630,14 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
return new_item;
}
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
const string& id, const string& meta_graph_file, const ItemConfig& cfg) {
MetaGraphDef meta_graph;
if (!ReadMetaGraphDefFromFile(meta_graph_file, &meta_graph).ok()) {
return nullptr;
}
return GrapplerItemFromMetaGraphDef(id, meta_graph, cfg);
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -58,6 +58,12 @@ struct ItemConfig {
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg);
// Factory method for creating a GrapplerItem from a file
// containing a MetaGraphDef in either binary or text format.
// Returns nullptr if the given meta_graph cannot be converted.
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
const string& id, const string& meta_graph_file, const ItemConfig& cfg);
} // end namespace grappler
} // end namespace tensorflow

View File

@ -35,11 +35,19 @@ bool FileExists(const string& file, Status* status) {
return status->ok();
}
Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
GraphDef* result) {
Status ReadGraphDefFromFile(const string& graph_def_path, GraphDef* result) {
Status status;
if (FileExists(graph_def_pbtxt_path, &status)) {
return ReadTextProto(Env::Default(), graph_def_pbtxt_path, result);
if (!ReadBinaryProto(Env::Default(), graph_def_path, result).ok()) {
return ReadTextProto(Env::Default(), graph_def_path, result);
}
return status;
}
Status ReadMetaGraphDefFromFile(const string& graph_def_path,
MetaGraphDef* result) {
Status status;
if (!ReadBinaryProto(Env::Default(), graph_def_path, result).ok()) {
return ReadTextProto(Env::Default(), graph_def_path, result);
}
return status;
}

View File

@ -20,7 +20,9 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
namespace tensorflow {
namespace grappler {
@ -31,8 +33,12 @@ bool FilesExist(const std::set<string>& files);
bool FileExists(const string& file, Status* status);
Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
GraphDef* result);
// Reads GraphDef from file in either text or raw serialized format.
Status ReadGraphDefFromFile(const string& graph_def_path, GraphDef* result);
// Reads MetaGraphDef from file in either text or raw serialized format.
Status ReadMetaGraphDefFromFile(const string& meta_graph_def_path,
MetaGraphDef* result);
} // end namespace grappler
} // end namespace tensorflow

View File

@ -31,6 +31,25 @@ class UtilsTest : public ::testing::Test {
non_existent_file_ = io::JoinPath(BaseDir(), "non_existent_file.txt");
actual_file_ = io::JoinPath(BaseDir(), "test_file.txt");
TF_CHECK_OK(WriteStringToFile(env_, actual_file_, "Some test data"));
text_graph_def_file_ = io::JoinPath(BaseDir(), "text_graph_def_file.txt");
binary_graph_def_file_ =
io::JoinPath(BaseDir(), "binary_graph_def_file.txt");
text_meta_graph_def_file_ =
io::JoinPath(BaseDir(), "text_meta_graph_def_file.txt");
binary_meta_graph_def_file_ =
io::JoinPath(BaseDir(), "binary_meta_graph_def_file.txt");
auto node = graph_def_.add_node();
node->set_name("foo");
node->set_op("bar");
TF_CHECK_OK(WriteTextProto(env_, text_graph_def_file_, graph_def_));
TF_CHECK_OK(WriteBinaryProto(env_, binary_graph_def_file_, graph_def_));
*meta_graph_def_.mutable_graph_def() = graph_def_;
TF_CHECK_OK(
WriteTextProto(env_, text_meta_graph_def_file_, meta_graph_def_));
TF_CHECK_OK(
WriteBinaryProto(env_, binary_meta_graph_def_file_, meta_graph_def_));
}
void TearDown() override {
@ -39,8 +58,14 @@ class UtilsTest : public ::testing::Test {
env_->DeleteRecursively(BaseDir(), &undeleted_files, &undeleted_dirs));
}
GraphDef graph_def_;
MetaGraphDef meta_graph_def_;
string non_existent_file_;
string actual_file_;
string text_graph_def_file_;
string binary_graph_def_file_;
string text_meta_graph_def_file_;
string binary_meta_graph_def_file_;
Env* env_ = Env::Default();
};
@ -58,6 +83,30 @@ TEST_F(UtilsTest, FilesExist) {
EXPECT_TRUE(status[1].ok());
}
TEST_F(UtilsTest, ReadGraphDefFromFile_Text) {
GraphDef result;
TF_CHECK_OK(ReadGraphDefFromFile(text_graph_def_file_, &result));
EXPECT_EQ(result.DebugString(), graph_def_.DebugString());
}
TEST_F(UtilsTest, ReadGraphDefFromFile_Binary) {
GraphDef result;
TF_CHECK_OK(ReadGraphDefFromFile(binary_graph_def_file_, &result));
EXPECT_EQ(result.DebugString(), graph_def_.DebugString());
}
TEST_F(UtilsTest, ReadMetaGraphDefFromFile_Text) {
MetaGraphDef result;
TF_CHECK_OK(ReadMetaGraphDefFromFile(text_meta_graph_def_file_, &result));
EXPECT_EQ(result.DebugString(), meta_graph_def_.DebugString());
}
TEST_F(UtilsTest, ReadReadMetaGraphDefFromFile_Binary) {
MetaGraphDef result;
TF_CHECK_OK(ReadMetaGraphDefFromFile(binary_meta_graph_def_file_, &result));
EXPECT_EQ(result.DebugString(), meta_graph_def_.DebugString());
}
} // namespace
} // namespace grappler
} // namespace tensorflow