Extend graph creation benchmark util to add control dependencies.

PiperOrigin-RevId: 239047668
This commit is contained in:
Andy Ly 2019-03-18 13:21:09 -07:00 committed by TensorFlower Gardener
parent 47ea63e7ea
commit 58ef2a922f
2 changed files with 46 additions and 18 deletions

View File

@ -135,7 +135,11 @@ GraphDef CreateRandomGraph(int size) {
return graph;
}
GraphDef CreateFaninFanoutNodeGraph(int num_fanins, int num_fanouts) {
GraphDef CreateFaninFanoutNodeGraph(int num_regular_fanins,
int num_regular_fanouts,
int num_controlling_fanins,
int num_controlled_fanouts,
bool fanout_unique_index) {
GraphDef graph;
auto create_node = [](const string& name) {
@ -146,16 +150,33 @@ GraphDef CreateFaninFanoutNodeGraph(int num_fanins, int num_fanouts) {
NodeDef node = create_node(/*name=*/"node");
for (int i = 0; i < num_fanins; ++i) {
for (int i = 0; i < num_regular_fanins; ++i) {
const string input_node_name = absl::StrFormat("in%05d", i);
NodeDef input_node = create_node(/*name=*/input_node_name);
*graph.add_node() = std::move(input_node);
node.add_input(input_node_name);
}
for (int i = 0; i < num_fanouts; ++i) {
for (int i = 0; i < num_controlling_fanins; ++i) {
const string input_node_name = absl::StrFormat("control_in%05d", i);
NodeDef input_node = create_node(/*name=*/input_node_name);
*graph.add_node() = std::move(input_node);
node.add_input(absl::StrCat("^", input_node_name));
}
for (int i = 0; i < num_regular_fanouts; ++i) {
NodeDef output_node = create_node(/*name=*/absl::StrFormat("out%05d", i));
output_node.add_input(absl::StrCat(node.name(), ":", i));
const string input_node_index =
fanout_unique_index ? absl::StrCat(node.name(), ":", i) : node.name();
output_node.add_input(input_node_index);
*graph.add_node() = std::move(output_node);
}
const string controlled_fanout_input = absl::StrCat("^", node.name());
for (int i = 0; i < num_controlled_fanouts; ++i) {
NodeDef output_node =
create_node(/*name=*/absl::StrFormat("control_out%05d", i));
output_node.add_input(controlled_fanout_input);
*graph.add_node() = std::move(output_node);
}

View File

@ -386,8 +386,9 @@ BENCHMARK(BM_GraphViewGetNode)
static void BM_GraphViewGetFanout(int iters, int num_fanins, int num_fanouts) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
const GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
GraphView graph_view(&graph_def);
testing::StartTiming();
@ -402,8 +403,9 @@ RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanout);
static void BM_GraphViewGetFanin(int iters, int num_fanins, int num_fanouts) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
const GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
GraphView graph_view(&graph_def);
testing::StartTiming();
@ -419,8 +421,9 @@ RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanin);
static void BM_GraphViewGetRegularFanin(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
const GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
GraphView graph_view(&graph_def);
testing::StartTiming();
@ -435,8 +438,9 @@ RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanin);
static void BM_GraphViewGetFanouts(int iters, int num_fanins, int num_fanouts) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
const GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
GraphView graph_view(&graph_def);
testing::StartTiming();
@ -451,8 +455,9 @@ RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanouts);
static void BM_GraphViewGetFanins(int iters, int num_fanins, int num_fanouts) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
const GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
GraphView graph_view(&graph_def);
testing::StartTiming();
@ -468,8 +473,9 @@ RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanins);
static void BM_GraphViewGetFanoutEdges(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
const GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
GraphView graph_view(&graph_def);
testing::StartTiming();
@ -485,8 +491,9 @@ RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanoutEdges);
static void BM_GraphViewGetFaninEdges(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
const GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins, num_fanouts);
const GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
GraphView graph_view(&graph_def);
testing::StartTiming();