diff --git a/tensorflow/tools/benchmark/benchmark_model.cc b/tensorflow/tools/benchmark/benchmark_model.cc index 15523028c72..eeb1fab40c4 100644 --- a/tensorflow/tools/benchmark/benchmark_model.cc +++ b/tensorflow/tools/benchmark/benchmark_model.cc @@ -261,6 +261,10 @@ Status InitializeSession(int num_threads, const string& graph, graph_def->reset(new GraphDef()); tensorflow::GraphDef tensorflow_graph; Status s = ReadBinaryProto(Env::Default(), graph, graph_def->get()); + if (!s.ok()) { + s = ReadTextProto(Env::Default(), graph, graph_def->get()); + } + if (!s.ok()) { LOG(ERROR) << "Could not create TensorFlow Graph: " << s; return s; diff --git a/tensorflow/tools/benchmark/benchmark_model_test.cc b/tensorflow/tools/benchmark/benchmark_model_test.cc index 16ab2ff66e7..6813045d632 100644 --- a/tensorflow/tools/benchmark/benchmark_model_test.cc +++ b/tensorflow/tools/benchmark/benchmark_model_test.cc @@ -26,30 +26,36 @@ limitations under the License. namespace tensorflow { namespace { -TEST(BenchmarkModelTest, InitializeAndRun) { - const string dir = testing::TmpDir(); - const string filename_pb = io::JoinPath(dir, "graphdef.pb"); - +void CreateTestGraph(const ::tensorflow::Scope& root, + benchmark_model::InputLayerInfo* input, + string* output_name, GraphDef* graph_def) { // Create a simple graph and write it to filename_pb. const int input_width = 400; const int input_height = 10; - benchmark_model::InputLayerInfo input; - input.shape = TensorShape({input_width, input_height}); - input.data_type = DT_FLOAT; + input->shape = TensorShape({input_width, input_height}); + input->data_type = DT_FLOAT; const TensorShape constant_shape({input_height, input_width}); Tensor constant_tensor(DT_FLOAT, constant_shape); test::FillFn(&constant_tensor, [](int) -> float { return 3.0; }); - auto root = Scope::NewRootScope().ExitOnError(); auto placeholder = - ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape(input.shape)); - input.name = placeholder.node()->name(); + ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape(input->shape)); + input->name = placeholder.node()->name(); auto m = ops::MatMul(root, placeholder, constant_tensor); - const string output_name = m.node()->name(); + *output_name = m.node()->name(); + TF_ASSERT_OK(root.ToGraphDef(graph_def)); +} +TEST(BenchmarkModelTest, InitializeAndRun) { + const string dir = testing::TmpDir(); + const string filename_pb = io::JoinPath(dir, "graphdef.pb"); + auto root = Scope::NewRootScope().ExitOnError(); + + benchmark_model::InputLayerInfo input; + string output_name; GraphDef graph_def; - TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + CreateTestGraph(root, &input, &output_name, &graph_def); string graph_def_serialized; graph_def.SerializeToString(&graph_def_serialized); TF_ASSERT_OK( @@ -69,5 +75,30 @@ TEST(BenchmarkModelTest, InitializeAndRun) { ASSERT_EQ(num_runs, 10); } +TEST(BenchmarkModeTest, TextProto) { + const string dir = testing::TmpDir(); + const string filename_txt = io::JoinPath(dir, "graphdef.pb.txt"); + auto root = Scope::NewRootScope().ExitOnError(); + + benchmark_model::InputLayerInfo input; + string output_name; + GraphDef graph_def; + CreateTestGraph(root, &input, &output_name, &graph_def); + TF_ASSERT_OK(WriteTextProto(Env::Default(), filename_txt, graph_def)); + + std::unique_ptr session; + std::unique_ptr loaded_graph_def; + TF_ASSERT_OK(benchmark_model::InitializeSession(1, filename_txt, &session, + &loaded_graph_def)); + std::unique_ptr stats; + stats.reset(new tensorflow::StatSummarizer(*(loaded_graph_def.get()))); + int64 time; + int64 num_runs = 0; + TF_ASSERT_OK(benchmark_model::TimeMultipleRuns( + 0.0, 10, 0.0, {input}, {output_name}, {}, session.get(), stats.get(), + &time, &num_runs)); + ASSERT_EQ(num_runs, 10); +} + } // namespace } // namespace tensorflow