From 58ef2a922f53bf7e464a0320c71896a111d2f272 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Mon, 18 Mar 2019 13:21:09 -0700 Subject: [PATCH] Extend graph creation benchmark util to add control dependencies. PiperOrigin-RevId: 239047668 --- tensorflow/core/graph/benchmark_testlib.h | 29 ++++++++++++++--- tensorflow/core/grappler/graph_view_test.cc | 35 ++++++++++++--------- 2 files changed, 46 insertions(+), 18 deletions(-) diff --git a/tensorflow/core/graph/benchmark_testlib.h b/tensorflow/core/graph/benchmark_testlib.h index 727cd07620d..dc5a21a85a2 100644 --- a/tensorflow/core/graph/benchmark_testlib.h +++ b/tensorflow/core/graph/benchmark_testlib.h @@ -135,7 +135,11 @@ GraphDef CreateRandomGraph(int size) { return graph; } -GraphDef CreateFaninFanoutNodeGraph(int num_fanins, int num_fanouts) { +GraphDef CreateFaninFanoutNodeGraph(int num_regular_fanins, + int num_regular_fanouts, + int num_controlling_fanins, + int num_controlled_fanouts, + bool fanout_unique_index) { GraphDef graph; auto create_node = [](const string& name) { @@ -146,16 +150,33 @@ GraphDef CreateFaninFanoutNodeGraph(int num_fanins, int num_fanouts) { NodeDef node = create_node(/*name=*/"node"); - for (int i = 0; i < num_fanins; ++i) { + for (int i = 0; i < num_regular_fanins; ++i) { const string input_node_name = absl::StrFormat("in%05d", i); NodeDef input_node = create_node(/*name=*/input_node_name); *graph.add_node() = std::move(input_node); node.add_input(input_node_name); } - for (int i = 0; i < num_fanouts; ++i) { + for (int i = 0; i < num_controlling_fanins; ++i) { + const string input_node_name = absl::StrFormat("control_in%05d", i); + NodeDef input_node = create_node(/*name=*/input_node_name); + *graph.add_node() = std::move(input_node); + node.add_input(absl::StrCat("^", input_node_name)); + } + + for (int i = 0; i < num_regular_fanouts; ++i) { NodeDef output_node = create_node(/*name=*/absl::StrFormat("out%05d", i)); - output_node.add_input(absl::StrCat(node.name(), ":", i)); + const string input_node_index = + fanout_unique_index ? absl::StrCat(node.name(), ":", i) : node.name(); + output_node.add_input(input_node_index); + *graph.add_node() = std::move(output_node); + } + + const string controlled_fanout_input = absl::StrCat("^", node.name()); + for (int i = 0; i < num_controlled_fanouts; ++i) { + NodeDef output_node = + create_node(/*name=*/absl::StrFormat("control_out%05d", i)); + output_node.add_input(controlled_fanout_input); *graph.add_node() = std::move(output_node); } diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc index 0036719fc51..5b3e140f23d 100644 --- a/tensorflow/core/grappler/graph_view_test.cc +++ b/tensorflow/core/grappler/graph_view_test.cc @@ -386,8 +386,9 @@ BENCHMARK(BM_GraphViewGetNode) static void BM_GraphViewGetFanout(int iters, int num_fanins, int num_fanouts) { testing::StopTiming(); - const GraphDef graph_def = - test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts); + const GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/true); GraphView graph_view(&graph_def); testing::StartTiming(); @@ -402,8 +403,9 @@ RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanout); static void BM_GraphViewGetFanin(int iters, int num_fanins, int num_fanouts) { testing::StopTiming(); - const GraphDef graph_def = - test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts); + const GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/true); GraphView graph_view(&graph_def); testing::StartTiming(); @@ -419,8 +421,9 @@ RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanin); static void BM_GraphViewGetRegularFanin(int iters, int num_fanins, int num_fanouts) { testing::StopTiming(); - const GraphDef graph_def = - test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts); + const GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/true); GraphView graph_view(&graph_def); testing::StartTiming(); @@ -435,8 +438,9 @@ RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanin); static void BM_GraphViewGetFanouts(int iters, int num_fanins, int num_fanouts) { testing::StopTiming(); - const GraphDef graph_def = - test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts); + const GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/true); GraphView graph_view(&graph_def); testing::StartTiming(); @@ -451,8 +455,9 @@ RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanouts); static void BM_GraphViewGetFanins(int iters, int num_fanins, int num_fanouts) { testing::StopTiming(); - const GraphDef graph_def = - test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts); + const GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/true); GraphView graph_view(&graph_def); testing::StartTiming(); @@ -468,8 +473,9 @@ RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanins); static void BM_GraphViewGetFanoutEdges(int iters, int num_fanins, int num_fanouts) { testing::StopTiming(); - const GraphDef graph_def = - test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts); + const GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/true); GraphView graph_view(&graph_def); testing::StartTiming(); @@ -485,8 +491,9 @@ RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanoutEdges); static void BM_GraphViewGetFaninEdges(int iters, int num_fanins, int num_fanouts) { testing::StopTiming(); - const GraphDef graph_def = - test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts); + const GraphDef graph_def = test::CreateFaninFanoutNodeGraph( + num_fanins, num_fanouts, num_fanins, num_fanouts, + /*fanout_unique_index=*/true); GraphView graph_view(&graph_def); testing::StartTiming();