Adds a simple util to build a GrapplerItem from a MetaGraphDef stored in a file.
PiperOrigin-RevId: 216622520
This commit is contained in:
parent
331683cb22
commit
2b010f2e48
@ -630,5 +630,14 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
||||
return new_item;
|
||||
}
|
||||
|
||||
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
|
||||
const string& id, const string& meta_graph_file, const ItemConfig& cfg) {
|
||||
MetaGraphDef meta_graph;
|
||||
if (!ReadMetaGraphDefFromFile(meta_graph_file, &meta_graph).ok()) {
|
||||
return nullptr;
|
||||
}
|
||||
return GrapplerItemFromMetaGraphDef(id, meta_graph, cfg);
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -58,6 +58,12 @@ struct ItemConfig {
|
||||
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
|
||||
const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg);
|
||||
|
||||
// Factory method for creating a GrapplerItem from a file
|
||||
// containing a MetaGraphDef in either binary or text format.
|
||||
// Returns nullptr if the given meta_graph cannot be converted.
|
||||
std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
|
||||
const string& id, const string& meta_graph_file, const ItemConfig& cfg);
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
@ -35,11 +35,19 @@ bool FileExists(const string& file, Status* status) {
|
||||
return status->ok();
|
||||
}
|
||||
|
||||
Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
|
||||
GraphDef* result) {
|
||||
Status ReadGraphDefFromFile(const string& graph_def_path, GraphDef* result) {
|
||||
Status status;
|
||||
if (FileExists(graph_def_pbtxt_path, &status)) {
|
||||
return ReadTextProto(Env::Default(), graph_def_pbtxt_path, result);
|
||||
if (!ReadBinaryProto(Env::Default(), graph_def_path, result).ok()) {
|
||||
return ReadTextProto(Env::Default(), graph_def_path, result);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
Status ReadMetaGraphDefFromFile(const string& graph_def_path,
|
||||
MetaGraphDef* result) {
|
||||
Status status;
|
||||
if (!ReadBinaryProto(Env::Default(), graph_def_path, result).ok()) {
|
||||
return ReadTextProto(Env::Default(), graph_def_path, result);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
@ -20,7 +20,9 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
@ -31,8 +33,12 @@ bool FilesExist(const std::set<string>& files);
|
||||
|
||||
bool FileExists(const string& file, Status* status);
|
||||
|
||||
Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
|
||||
GraphDef* result);
|
||||
// Reads GraphDef from file in either text or raw serialized format.
|
||||
Status ReadGraphDefFromFile(const string& graph_def_path, GraphDef* result);
|
||||
|
||||
// Reads MetaGraphDef from file in either text or raw serialized format.
|
||||
Status ReadMetaGraphDefFromFile(const string& meta_graph_def_path,
|
||||
MetaGraphDef* result);
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -31,6 +31,25 @@ class UtilsTest : public ::testing::Test {
|
||||
non_existent_file_ = io::JoinPath(BaseDir(), "non_existent_file.txt");
|
||||
actual_file_ = io::JoinPath(BaseDir(), "test_file.txt");
|
||||
TF_CHECK_OK(WriteStringToFile(env_, actual_file_, "Some test data"));
|
||||
|
||||
text_graph_def_file_ = io::JoinPath(BaseDir(), "text_graph_def_file.txt");
|
||||
binary_graph_def_file_ =
|
||||
io::JoinPath(BaseDir(), "binary_graph_def_file.txt");
|
||||
text_meta_graph_def_file_ =
|
||||
io::JoinPath(BaseDir(), "text_meta_graph_def_file.txt");
|
||||
binary_meta_graph_def_file_ =
|
||||
io::JoinPath(BaseDir(), "binary_meta_graph_def_file.txt");
|
||||
|
||||
auto node = graph_def_.add_node();
|
||||
node->set_name("foo");
|
||||
node->set_op("bar");
|
||||
TF_CHECK_OK(WriteTextProto(env_, text_graph_def_file_, graph_def_));
|
||||
TF_CHECK_OK(WriteBinaryProto(env_, binary_graph_def_file_, graph_def_));
|
||||
*meta_graph_def_.mutable_graph_def() = graph_def_;
|
||||
TF_CHECK_OK(
|
||||
WriteTextProto(env_, text_meta_graph_def_file_, meta_graph_def_));
|
||||
TF_CHECK_OK(
|
||||
WriteBinaryProto(env_, binary_meta_graph_def_file_, meta_graph_def_));
|
||||
}
|
||||
|
||||
void TearDown() override {
|
||||
@ -39,8 +58,14 @@ class UtilsTest : public ::testing::Test {
|
||||
env_->DeleteRecursively(BaseDir(), &undeleted_files, &undeleted_dirs));
|
||||
}
|
||||
|
||||
GraphDef graph_def_;
|
||||
MetaGraphDef meta_graph_def_;
|
||||
string non_existent_file_;
|
||||
string actual_file_;
|
||||
string text_graph_def_file_;
|
||||
string binary_graph_def_file_;
|
||||
string text_meta_graph_def_file_;
|
||||
string binary_meta_graph_def_file_;
|
||||
Env* env_ = Env::Default();
|
||||
};
|
||||
|
||||
@ -58,6 +83,30 @@ TEST_F(UtilsTest, FilesExist) {
|
||||
EXPECT_TRUE(status[1].ok());
|
||||
}
|
||||
|
||||
TEST_F(UtilsTest, ReadGraphDefFromFile_Text) {
|
||||
GraphDef result;
|
||||
TF_CHECK_OK(ReadGraphDefFromFile(text_graph_def_file_, &result));
|
||||
EXPECT_EQ(result.DebugString(), graph_def_.DebugString());
|
||||
}
|
||||
|
||||
TEST_F(UtilsTest, ReadGraphDefFromFile_Binary) {
|
||||
GraphDef result;
|
||||
TF_CHECK_OK(ReadGraphDefFromFile(binary_graph_def_file_, &result));
|
||||
EXPECT_EQ(result.DebugString(), graph_def_.DebugString());
|
||||
}
|
||||
|
||||
TEST_F(UtilsTest, ReadMetaGraphDefFromFile_Text) {
|
||||
MetaGraphDef result;
|
||||
TF_CHECK_OK(ReadMetaGraphDefFromFile(text_meta_graph_def_file_, &result));
|
||||
EXPECT_EQ(result.DebugString(), meta_graph_def_.DebugString());
|
||||
}
|
||||
|
||||
TEST_F(UtilsTest, ReadReadMetaGraphDefFromFile_Binary) {
|
||||
MetaGraphDef result;
|
||||
TF_CHECK_OK(ReadMetaGraphDefFromFile(binary_meta_graph_def_file_, &result));
|
||||
EXPECT_EQ(result.DebugString(), meta_graph_def_.DebugString());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user