[Grappler] GraphView (immutable) benchmarks.

PiperOrigin-RevId: 238314797
This commit is contained in:
Andy Ly 2019-03-13 14:50:43 -07:00 committed by TensorFlower Gardener
parent 3e377e72fb
commit 0993d774a8
3 changed files with 241 additions and 0 deletions

View File

@ -135,6 +135,35 @@ GraphDef CreateRandomGraph(int size) {
return graph; return graph;
} }
GraphDef CreateFaninFanoutNodeGraph(int num_fanins, int num_fanouts) {
GraphDef graph;
auto create_node = [](const string& name) {
NodeDef node;
node.set_name(name);
return node;
};
NodeDef node = create_node(/*name=*/"node");
for (int i = 0; i < num_fanins; ++i) {
const string input_node_name = absl::StrFormat("in%05d", i);
NodeDef input_node = create_node(/*name=*/input_node_name);
*graph.add_node() = std::move(input_node);
node.add_input(input_node_name);
}
for (int i = 0; i < num_fanouts; ++i) {
NodeDef output_node = create_node(/*name=*/absl::StrFormat("out%05d", i));
output_node.add_input(absl::StrCat(node.name(), ":", i));
*graph.add_node() = std::move(output_node);
}
*graph.add_node() = std::move(node);
return graph;
}
} // namespace test } // namespace test
} // namespace tensorflow } // namespace tensorflow

View File

@ -117,8 +117,10 @@ tf_cc_test(
":graph_view", ":graph_view",
":grappler_item", ":grappler_item",
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops",
"//tensorflow/core:lib",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",

View File

@ -18,10 +18,12 @@ limitations under the License.
#include "absl/strings/str_cat.h" #include "absl/strings/str_cat.h"
#include "tensorflow/cc/ops/parsing_ops.h" #include "tensorflow/cc/ops/parsing_ops.h"
#include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/graph/benchmark_testlib.h"
#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow { namespace tensorflow {
namespace grappler { namespace grappler {
@ -289,6 +291,214 @@ TEST_F(GraphViewTest, GetRegularFaninPortOutOfBounds) {
EXPECT_EQ(d_output_control, GraphView::OutputPort()); EXPECT_EQ(d_output_control, GraphView::OutputPort());
} }
static void BM_GraphViewConstruction(int iters, int num_nodes,
int num_edges_per_node) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateGraphDef(num_nodes, num_edges_per_node);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
GraphView graph_view(&graph_def);
}
testing::StopTiming();
}
BENCHMARK(BM_GraphViewConstruction)
->ArgPair(10, 2)
->ArgPair(100, 2)
->ArgPair(1000, 2)
->ArgPair(10000, 2)
->ArgPair(25000, 2)
->ArgPair(50000, 2)
->ArgPair(100000, 2)
->ArgPair(10, 4)
->ArgPair(100, 4)
->ArgPair(1000, 4)
->ArgPair(10000, 4)
->ArgPair(25000, 4)
->ArgPair(50000, 4)
->ArgPair(100000, 4)
->ArgPair(10, 8)
->ArgPair(100, 8)
->ArgPair(1000, 8)
->ArgPair(10000, 8)
->ArgPair(25000, 8)
->ArgPair(50000, 8)
->ArgPair(100000, 8)
->ArgPair(10, 16)
->ArgPair(100, 16)
->ArgPair(1000, 16)
->ArgPair(10000, 16)
->ArgPair(25000, 16)
->ArgPair(50000, 16)
->ArgPair(100000, 16);
static void BM_GraphViewGetNode(int iters, int num_nodes) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateGraphDef(num_nodes, /*num_edges_per_node=*/16);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
graph_view.GetNode("out");
}
testing::StopTiming();
}
BENCHMARK(BM_GraphViewGetNode)
->Arg(10)
->Arg(100)
->Arg(1000)
->Arg(10000)
->Arg(25000)
->Arg(50000)
->Arg(100000);
#define RUN_FANIN_FANOUT_BENCHMARK(name) \
BENCHMARK(name) \
->ArgPair(10, 10) \
->ArgPair(10, 100) \
->ArgPair(10, 1000) \
->ArgPair(10, 10000) \
->ArgPair(10, 100000) \
->ArgPair(100, 10) \
->ArgPair(100, 100) \
->ArgPair(100, 1000) \
->ArgPair(100, 10000) \
->ArgPair(100, 100000) \
->ArgPair(1000, 10) \
->ArgPair(1000, 100) \
->ArgPair(1000, 1000) \
->ArgPair(1000, 10000) \
->ArgPair(1000, 100000) \
->ArgPair(10000, 10) \
->ArgPair(10000, 100) \
->ArgPair(10000, 1000) \
->ArgPair(10000, 10000) \
->ArgPair(10000, 100000) \
->ArgPair(100000, 10) \
->ArgPair(100000, 100) \
->ArgPair(100000, 1000) \
->ArgPair(100000, 10000) \
->ArgPair(100000, 100000);
static void BM_GraphViewGetFanout(int iters, int num_fanins, int num_fanouts) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
const NodeDef* node = graph_view.GetNode("node");
graph_view.GetFanout({node, 0});
}
testing::StopTiming();
}
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanout);
static void BM_GraphViewGetFanin(int iters, int num_fanins, int num_fanouts) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
const NodeDef* node = graph_view.GetNode("node");
graph_view.GetFanin({node, 0});
}
testing::StopTiming();
}
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanin);
static void BM_GraphViewGetRegularFanin(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
const NodeDef* node = graph_view.GetNode("node");
graph_view.GetRegularFanin({node, 0});
}
testing::StopTiming();
}
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanin);
static void BM_GraphViewGetFanouts(int iters, int num_fanins, int num_fanouts) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
const NodeDef* node = graph_view.GetNode("node");
graph_view.GetFanouts(*node, /*include_controlled_nodes=*/false);
}
testing::StopTiming();
}
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanouts);
static void BM_GraphViewGetFanins(int iters, int num_fanins, int num_fanouts) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
const NodeDef* node = graph_view.GetNode("node");
graph_view.GetFanins(*node, /*include_controlling_nodes=*/false);
}
testing::StopTiming();
}
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanins);
static void BM_GraphViewGetFanoutEdges(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
const NodeDef* node = graph_view.GetNode("node");
graph_view.GetFanoutEdges(*node, /*include_controlled_edges=*/false);
}
testing::StopTiming();
}
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanoutEdges);
static void BM_GraphViewGetFaninEdges(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
const NodeDef* node = graph_view.GetNode("node");
graph_view.GetFaninEdges(*node, /*include_controlling_edges=*/false);
}
testing::StopTiming();
}
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFaninEdges);
} // namespace } // namespace
} // namespace grappler } // namespace grappler
} // namespace tensorflow } // namespace tensorflow