[Grappler] GraphView (immutable) benchmarks.
PiperOrigin-RevId: 238314797
This commit is contained in:
parent
3e377e72fb
commit
0993d774a8
@ -135,6 +135,35 @@ GraphDef CreateRandomGraph(int size) {
|
||||
return graph;
|
||||
}
|
||||
|
||||
GraphDef CreateFaninFanoutNodeGraph(int num_fanins, int num_fanouts) {
|
||||
GraphDef graph;
|
||||
|
||||
auto create_node = [](const string& name) {
|
||||
NodeDef node;
|
||||
node.set_name(name);
|
||||
return node;
|
||||
};
|
||||
|
||||
NodeDef node = create_node(/*name=*/"node");
|
||||
|
||||
for (int i = 0; i < num_fanins; ++i) {
|
||||
const string input_node_name = absl::StrFormat("in%05d", i);
|
||||
NodeDef input_node = create_node(/*name=*/input_node_name);
|
||||
*graph.add_node() = std::move(input_node);
|
||||
node.add_input(input_node_name);
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_fanouts; ++i) {
|
||||
NodeDef output_node = create_node(/*name=*/absl::StrFormat("out%05d", i));
|
||||
output_node.add_input(absl::StrCat(node.name(), ":", i));
|
||||
*graph.add_node() = std::move(output_node);
|
||||
}
|
||||
|
||||
*graph.add_node() = std::move(node);
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -117,8 +117,10 @@ tf_cc_test(
|
||||
":graph_view",
|
||||
":grappler_item",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
|
@ -18,10 +18,12 @@ limitations under the License.
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/cc/ops/parsing_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/graph/benchmark_testlib.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
@ -289,6 +291,214 @@ TEST_F(GraphViewTest, GetRegularFaninPortOutOfBounds) {
|
||||
EXPECT_EQ(d_output_control, GraphView::OutputPort());
|
||||
}
|
||||
|
||||
static void BM_GraphViewConstruction(int iters, int num_nodes,
|
||||
int num_edges_per_node) {
|
||||
testing::StopTiming();
|
||||
const GraphDef graph_def =
|
||||
test::CreateGraphDef(num_nodes, num_edges_per_node);
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
GraphView graph_view(&graph_def);
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
BENCHMARK(BM_GraphViewConstruction)
|
||||
->ArgPair(10, 2)
|
||||
->ArgPair(100, 2)
|
||||
->ArgPair(1000, 2)
|
||||
->ArgPair(10000, 2)
|
||||
->ArgPair(25000, 2)
|
||||
->ArgPair(50000, 2)
|
||||
->ArgPair(100000, 2)
|
||||
->ArgPair(10, 4)
|
||||
->ArgPair(100, 4)
|
||||
->ArgPair(1000, 4)
|
||||
->ArgPair(10000, 4)
|
||||
->ArgPair(25000, 4)
|
||||
->ArgPair(50000, 4)
|
||||
->ArgPair(100000, 4)
|
||||
->ArgPair(10, 8)
|
||||
->ArgPair(100, 8)
|
||||
->ArgPair(1000, 8)
|
||||
->ArgPair(10000, 8)
|
||||
->ArgPair(25000, 8)
|
||||
->ArgPair(50000, 8)
|
||||
->ArgPair(100000, 8)
|
||||
->ArgPair(10, 16)
|
||||
->ArgPair(100, 16)
|
||||
->ArgPair(1000, 16)
|
||||
->ArgPair(10000, 16)
|
||||
->ArgPair(25000, 16)
|
||||
->ArgPair(50000, 16)
|
||||
->ArgPair(100000, 16);
|
||||
|
||||
static void BM_GraphViewGetNode(int iters, int num_nodes) {
|
||||
testing::StopTiming();
|
||||
const GraphDef graph_def =
|
||||
test::CreateGraphDef(num_nodes, /*num_edges_per_node=*/16);
|
||||
GraphView graph_view(&graph_def);
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
graph_view.GetNode("out");
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
BENCHMARK(BM_GraphViewGetNode)
|
||||
->Arg(10)
|
||||
->Arg(100)
|
||||
->Arg(1000)
|
||||
->Arg(10000)
|
||||
->Arg(25000)
|
||||
->Arg(50000)
|
||||
->Arg(100000);
|
||||
|
||||
#define RUN_FANIN_FANOUT_BENCHMARK(name) \
|
||||
BENCHMARK(name) \
|
||||
->ArgPair(10, 10) \
|
||||
->ArgPair(10, 100) \
|
||||
->ArgPair(10, 1000) \
|
||||
->ArgPair(10, 10000) \
|
||||
->ArgPair(10, 100000) \
|
||||
->ArgPair(100, 10) \
|
||||
->ArgPair(100, 100) \
|
||||
->ArgPair(100, 1000) \
|
||||
->ArgPair(100, 10000) \
|
||||
->ArgPair(100, 100000) \
|
||||
->ArgPair(1000, 10) \
|
||||
->ArgPair(1000, 100) \
|
||||
->ArgPair(1000, 1000) \
|
||||
->ArgPair(1000, 10000) \
|
||||
->ArgPair(1000, 100000) \
|
||||
->ArgPair(10000, 10) \
|
||||
->ArgPair(10000, 100) \
|
||||
->ArgPair(10000, 1000) \
|
||||
->ArgPair(10000, 10000) \
|
||||
->ArgPair(10000, 100000) \
|
||||
->ArgPair(100000, 10) \
|
||||
->ArgPair(100000, 100) \
|
||||
->ArgPair(100000, 1000) \
|
||||
->ArgPair(100000, 10000) \
|
||||
->ArgPair(100000, 100000);
|
||||
|
||||
static void BM_GraphViewGetFanout(int iters, int num_fanins, int num_fanouts) {
|
||||
testing::StopTiming();
|
||||
const GraphDef graph_def =
|
||||
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
|
||||
GraphView graph_view(&graph_def);
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
const NodeDef* node = graph_view.GetNode("node");
|
||||
graph_view.GetFanout({node, 0});
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanout);
|
||||
|
||||
static void BM_GraphViewGetFanin(int iters, int num_fanins, int num_fanouts) {
|
||||
testing::StopTiming();
|
||||
const GraphDef graph_def =
|
||||
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
|
||||
GraphView graph_view(&graph_def);
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
const NodeDef* node = graph_view.GetNode("node");
|
||||
graph_view.GetFanin({node, 0});
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanin);
|
||||
|
||||
static void BM_GraphViewGetRegularFanin(int iters, int num_fanins,
|
||||
int num_fanouts) {
|
||||
testing::StopTiming();
|
||||
const GraphDef graph_def =
|
||||
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
|
||||
GraphView graph_view(&graph_def);
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
const NodeDef* node = graph_view.GetNode("node");
|
||||
graph_view.GetRegularFanin({node, 0});
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanin);
|
||||
|
||||
static void BM_GraphViewGetFanouts(int iters, int num_fanins, int num_fanouts) {
|
||||
testing::StopTiming();
|
||||
const GraphDef graph_def =
|
||||
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
|
||||
GraphView graph_view(&graph_def);
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
const NodeDef* node = graph_view.GetNode("node");
|
||||
graph_view.GetFanouts(*node, /*include_controlled_nodes=*/false);
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanouts);
|
||||
|
||||
static void BM_GraphViewGetFanins(int iters, int num_fanins, int num_fanouts) {
|
||||
testing::StopTiming();
|
||||
const GraphDef graph_def =
|
||||
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
|
||||
GraphView graph_view(&graph_def);
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
const NodeDef* node = graph_view.GetNode("node");
|
||||
graph_view.GetFanins(*node, /*include_controlling_nodes=*/false);
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanins);
|
||||
|
||||
static void BM_GraphViewGetFanoutEdges(int iters, int num_fanins,
|
||||
int num_fanouts) {
|
||||
testing::StopTiming();
|
||||
const GraphDef graph_def =
|
||||
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
|
||||
GraphView graph_view(&graph_def);
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
const NodeDef* node = graph_view.GetNode("node");
|
||||
graph_view.GetFanoutEdges(*node, /*include_controlled_edges=*/false);
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanoutEdges);
|
||||
|
||||
static void BM_GraphViewGetFaninEdges(int iters, int num_fanins,
|
||||
int num_fanouts) {
|
||||
testing::StopTiming();
|
||||
const GraphDef graph_def =
|
||||
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
|
||||
GraphView graph_view(&graph_def);
|
||||
|
||||
testing::StartTiming();
|
||||
for (int i = 0; i < iters; ++i) {
|
||||
const NodeDef* node = graph_view.GetNode("node");
|
||||
graph_view.GetFaninEdges(*node, /*include_controlling_edges=*/false);
|
||||
}
|
||||
testing::StopTiming();
|
||||
}
|
||||
|
||||
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFaninEdges);
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user