Allow benchmark model graph to be specified in text proto format.
PiperOrigin-RevId: 195547670
This commit is contained in:
parent
f3c21911bc
commit
62ed0aa370
@ -261,6 +261,10 @@ Status InitializeSession(int num_threads, const string& graph,
|
||||
graph_def->reset(new GraphDef());
|
||||
tensorflow::GraphDef tensorflow_graph;
|
||||
Status s = ReadBinaryProto(Env::Default(), graph, graph_def->get());
|
||||
if (!s.ok()) {
|
||||
s = ReadTextProto(Env::Default(), graph, graph_def->get());
|
||||
}
|
||||
|
||||
if (!s.ok()) {
|
||||
LOG(ERROR) << "Could not create TensorFlow Graph: " << s;
|
||||
return s;
|
||||
|
@ -26,30 +26,36 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(BenchmarkModelTest, InitializeAndRun) {
|
||||
const string dir = testing::TmpDir();
|
||||
const string filename_pb = io::JoinPath(dir, "graphdef.pb");
|
||||
|
||||
void CreateTestGraph(const ::tensorflow::Scope& root,
|
||||
benchmark_model::InputLayerInfo* input,
|
||||
string* output_name, GraphDef* graph_def) {
|
||||
// Create a simple graph and write it to filename_pb.
|
||||
const int input_width = 400;
|
||||
const int input_height = 10;
|
||||
benchmark_model::InputLayerInfo input;
|
||||
input.shape = TensorShape({input_width, input_height});
|
||||
input.data_type = DT_FLOAT;
|
||||
input->shape = TensorShape({input_width, input_height});
|
||||
input->data_type = DT_FLOAT;
|
||||
const TensorShape constant_shape({input_height, input_width});
|
||||
|
||||
Tensor constant_tensor(DT_FLOAT, constant_shape);
|
||||
test::FillFn<float>(&constant_tensor, [](int) -> float { return 3.0; });
|
||||
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
auto placeholder =
|
||||
ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape(input.shape));
|
||||
input.name = placeholder.node()->name();
|
||||
ops::Placeholder(root, DT_FLOAT, ops::Placeholder::Shape(input->shape));
|
||||
input->name = placeholder.node()->name();
|
||||
auto m = ops::MatMul(root, placeholder, constant_tensor);
|
||||
const string output_name = m.node()->name();
|
||||
*output_name = m.node()->name();
|
||||
TF_ASSERT_OK(root.ToGraphDef(graph_def));
|
||||
}
|
||||
|
||||
TEST(BenchmarkModelTest, InitializeAndRun) {
|
||||
const string dir = testing::TmpDir();
|
||||
const string filename_pb = io::JoinPath(dir, "graphdef.pb");
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
benchmark_model::InputLayerInfo input;
|
||||
string output_name;
|
||||
GraphDef graph_def;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
|
||||
CreateTestGraph(root, &input, &output_name, &graph_def);
|
||||
string graph_def_serialized;
|
||||
graph_def.SerializeToString(&graph_def_serialized);
|
||||
TF_ASSERT_OK(
|
||||
@ -69,5 +75,30 @@ TEST(BenchmarkModelTest, InitializeAndRun) {
|
||||
ASSERT_EQ(num_runs, 10);
|
||||
}
|
||||
|
||||
TEST(BenchmarkModeTest, TextProto) {
|
||||
const string dir = testing::TmpDir();
|
||||
const string filename_txt = io::JoinPath(dir, "graphdef.pb.txt");
|
||||
auto root = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
benchmark_model::InputLayerInfo input;
|
||||
string output_name;
|
||||
GraphDef graph_def;
|
||||
CreateTestGraph(root, &input, &output_name, &graph_def);
|
||||
TF_ASSERT_OK(WriteTextProto(Env::Default(), filename_txt, graph_def));
|
||||
|
||||
std::unique_ptr<Session> session;
|
||||
std::unique_ptr<GraphDef> loaded_graph_def;
|
||||
TF_ASSERT_OK(benchmark_model::InitializeSession(1, filename_txt, &session,
|
||||
&loaded_graph_def));
|
||||
std::unique_ptr<StatSummarizer> stats;
|
||||
stats.reset(new tensorflow::StatSummarizer(*(loaded_graph_def.get())));
|
||||
int64 time;
|
||||
int64 num_runs = 0;
|
||||
TF_ASSERT_OK(benchmark_model::TimeMultipleRuns(
|
||||
0.0, 10, 0.0, {input}, {output_name}, {}, session.get(), stats.get(),
|
||||
&time, &num_runs));
|
||||
ASSERT_EQ(num_runs, 10);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user