From 0993d774a8a10e7194e5ea4dbfdf7836e23eb326 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Wed, 13 Mar 2019 14:50:43 -0700 Subject: [PATCH] [Grappler] GraphView (immutable) benchmarks. PiperOrigin-RevId: 238314797 --- tensorflow/core/graph/benchmark_testlib.h | 29 +++ tensorflow/core/grappler/BUILD | 2 + tensorflow/core/grappler/graph_view_test.cc | 210 ++++++++++++++++++++ 3 files changed, 241 insertions(+) diff --git a/tensorflow/core/graph/benchmark_testlib.h b/tensorflow/core/graph/benchmark_testlib.h index 4322352cb47..727cd07620d 100644 --- a/tensorflow/core/graph/benchmark_testlib.h +++ b/tensorflow/core/graph/benchmark_testlib.h @@ -135,6 +135,35 @@ GraphDef CreateRandomGraph(int size) { return graph; } +GraphDef CreateFaninFanoutNodeGraph(int num_fanins, int num_fanouts) { + GraphDef graph; + + auto create_node = [](const string& name) { + NodeDef node; + node.set_name(name); + return node; + }; + + NodeDef node = create_node(/*name=*/"node"); + + for (int i = 0; i < num_fanins; ++i) { + const string input_node_name = absl::StrFormat("in%05d", i); + NodeDef input_node = create_node(/*name=*/input_node_name); + *graph.add_node() = std::move(input_node); + node.add_input(input_node_name); + } + + for (int i = 0; i < num_fanouts; ++i) { + NodeDef output_node = create_node(/*name=*/absl::StrFormat("out%05d", i)); + output_node.add_input(absl::StrCat(node.name(), ":", i)); + *graph.add_node() = std::move(output_node); + } + + *graph.add_node() = std::move(node); + + return graph; +} + } // namespace test } // namespace tensorflow diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index 77307708fab..b5f223facbd 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -117,8 +117,10 @@ tf_cc_test( ":graph_view", ":grappler_item", "//tensorflow/cc:cc_ops", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc index 839057065b4..0036719fc51 100644 --- a/tensorflow/core/grappler/graph_view_test.cc +++ b/tensorflow/core/grappler/graph_view_test.cc @@ -18,10 +18,12 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/cc/ops/parsing_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/graph/benchmark_testlib.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace tensorflow { namespace grappler { @@ -289,6 +291,214 @@ TEST_F(GraphViewTest, GetRegularFaninPortOutOfBounds) { EXPECT_EQ(d_output_control, GraphView::OutputPort()); } +static void BM_GraphViewConstruction(int iters, int num_nodes, + int num_edges_per_node) { + testing::StopTiming(); + const GraphDef graph_def = + test::CreateGraphDef(num_nodes, num_edges_per_node); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + GraphView graph_view(&graph_def); + } + testing::StopTiming(); +} + +BENCHMARK(BM_GraphViewConstruction) + ->ArgPair(10, 2) + ->ArgPair(100, 2) + ->ArgPair(1000, 2) + ->ArgPair(10000, 2) + ->ArgPair(25000, 2) + ->ArgPair(50000, 2) + ->ArgPair(100000, 2) + ->ArgPair(10, 4) + ->ArgPair(100, 4) + ->ArgPair(1000, 4) + ->ArgPair(10000, 4) + ->ArgPair(25000, 4) + ->ArgPair(50000, 4) + ->ArgPair(100000, 4) + ->ArgPair(10, 8) + ->ArgPair(100, 8) + ->ArgPair(1000, 8) + ->ArgPair(10000, 8) + ->ArgPair(25000, 8) + ->ArgPair(50000, 8) + ->ArgPair(100000, 8) + ->ArgPair(10, 16) + ->ArgPair(100, 16) + ->ArgPair(1000, 16) + ->ArgPair(10000, 16) + ->ArgPair(25000, 16) + ->ArgPair(50000, 16) + ->ArgPair(100000, 16); + +static void BM_GraphViewGetNode(int iters, int num_nodes) { + testing::StopTiming(); + const GraphDef graph_def = + test::CreateGraphDef(num_nodes, /*num_edges_per_node=*/16); + GraphView graph_view(&graph_def); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + graph_view.GetNode("out"); + } + testing::StopTiming(); +} + +BENCHMARK(BM_GraphViewGetNode) + ->Arg(10) + ->Arg(100) + ->Arg(1000) + ->Arg(10000) + ->Arg(25000) + ->Arg(50000) + ->Arg(100000); + +#define RUN_FANIN_FANOUT_BENCHMARK(name) \ + BENCHMARK(name) \ + ->ArgPair(10, 10) \ + ->ArgPair(10, 100) \ + ->ArgPair(10, 1000) \ + ->ArgPair(10, 10000) \ + ->ArgPair(10, 100000) \ + ->ArgPair(100, 10) \ + ->ArgPair(100, 100) \ + ->ArgPair(100, 1000) \ + ->ArgPair(100, 10000) \ + ->ArgPair(100, 100000) \ + ->ArgPair(1000, 10) \ + ->ArgPair(1000, 100) \ + ->ArgPair(1000, 1000) \ + ->ArgPair(1000, 10000) \ + ->ArgPair(1000, 100000) \ + ->ArgPair(10000, 10) \ + ->ArgPair(10000, 100) \ + ->ArgPair(10000, 1000) \ + ->ArgPair(10000, 10000) \ + ->ArgPair(10000, 100000) \ + ->ArgPair(100000, 10) \ + ->ArgPair(100000, 100) \ + ->ArgPair(100000, 1000) \ + ->ArgPair(100000, 10000) \ + ->ArgPair(100000, 100000); + +static void BM_GraphViewGetFanout(int iters, int num_fanins, int num_fanouts) { + testing::StopTiming(); + const GraphDef graph_def = + test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts); + GraphView graph_view(&graph_def); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + const NodeDef* node = graph_view.GetNode("node"); + graph_view.GetFanout({node, 0}); + } + testing::StopTiming(); +} + +RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanout); + +static void BM_GraphViewGetFanin(int iters, int num_fanins, int num_fanouts) { + testing::StopTiming(); + const GraphDef graph_def = + test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts); + GraphView graph_view(&graph_def); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + const NodeDef* node = graph_view.GetNode("node"); + graph_view.GetFanin({node, 0}); + } + testing::StopTiming(); +} + +RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanin); + +static void BM_GraphViewGetRegularFanin(int iters, int num_fanins, + int num_fanouts) { + testing::StopTiming(); + const GraphDef graph_def = + test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts); + GraphView graph_view(&graph_def); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + const NodeDef* node = graph_view.GetNode("node"); + graph_view.GetRegularFanin({node, 0}); + } + testing::StopTiming(); +} + +RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanin); + +static void BM_GraphViewGetFanouts(int iters, int num_fanins, int num_fanouts) { + testing::StopTiming(); + const GraphDef graph_def = + test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts); + GraphView graph_view(&graph_def); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + const NodeDef* node = graph_view.GetNode("node"); + graph_view.GetFanouts(*node, /*include_controlled_nodes=*/false); + } + testing::StopTiming(); +} + +RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanouts); + +static void BM_GraphViewGetFanins(int iters, int num_fanins, int num_fanouts) { + testing::StopTiming(); + const GraphDef graph_def = + test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts); + GraphView graph_view(&graph_def); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + const NodeDef* node = graph_view.GetNode("node"); + graph_view.GetFanins(*node, /*include_controlling_nodes=*/false); + } + testing::StopTiming(); +} + +RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanins); + +static void BM_GraphViewGetFanoutEdges(int iters, int num_fanins, + int num_fanouts) { + testing::StopTiming(); + const GraphDef graph_def = + test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts); + GraphView graph_view(&graph_def); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + const NodeDef* node = graph_view.GetNode("node"); + graph_view.GetFanoutEdges(*node, /*include_controlled_edges=*/false); + } + testing::StopTiming(); +} + +RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanoutEdges); + +static void BM_GraphViewGetFaninEdges(int iters, int num_fanins, + int num_fanouts) { + testing::StopTiming(); + const GraphDef graph_def = + test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts); + GraphView graph_view(&graph_def); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + const NodeDef* node = graph_view.GetNode("node"); + graph_view.GetFaninEdges(*node, /*include_controlling_edges=*/false); + } + testing::StopTiming(); +} + +RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFaninEdges); + } // namespace } // namespace grappler } // namespace tensorflow