Internal tests cleanup

PiperOrigin-RevId: 339366980
Change-Id: Ic12014b8db7717d6873a47143d9c9a4098b0f4be
This commit is contained in:
A. Unique TensorFlower 2020-10-27 17:44:33 -07:00 committed by TensorFlower Gardener
parent f0e5b4e28d
commit 22380724dd
4 changed files with 228 additions and 293 deletions

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/cc/ops/parsing_ops.h"
@ -291,17 +292,16 @@ TEST_F(GraphViewTest, GetRegularFaninPortOutOfBounds) {
EXPECT_EQ(d_output_control, GraphView::OutputPort());
}
static void BM_GraphViewConstruction(int iters, int num_nodes,
int num_edges_per_node) {
testing::StopTiming();
void BM_GraphViewConstruction(::testing::benchmark::State& state) {
const int num_nodes = state.range(0);
const int num_edges_per_node = state.range(1);
const GraphDef graph_def =
test::CreateGraphDef(num_nodes, num_edges_per_node);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
GraphView graph_view(&graph_def);
}
testing::StopTiming();
}
BENCHMARK(BM_GraphViewConstruction)
@ -334,17 +334,16 @@ BENCHMARK(BM_GraphViewConstruction)
->ArgPair(50000, 16)
->ArgPair(100000, 16);
static void BM_GraphViewGetNode(int iters, int num_nodes) {
testing::StopTiming();
void BM_GraphViewGetNode(::testing::benchmark::State& state) {
const int num_nodes = state.range(0);
const GraphDef graph_def =
test::CreateGraphDef(num_nodes, /*num_edges_per_node=*/16);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
graph_view.GetNode("out");
}
testing::StopTiming();
}
BENCHMARK(BM_GraphViewGetNode)
@ -384,124 +383,121 @@ BENCHMARK(BM_GraphViewGetNode)
->ArgPair(100000, 10000) \
->ArgPair(100000, 100000);
static void BM_GraphViewGetFanout(int iters, int num_fanins, int num_fanouts) {
testing::StopTiming();
void BM_GraphViewGetFanout(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
const GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
const NodeDef* node = graph_view.GetNode("node");
graph_view.GetFanout({node, 0});
}
testing::StopTiming();
}
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanout);
static void BM_GraphViewGetFanin(int iters, int num_fanins, int num_fanouts) {
testing::StopTiming();
void BM_GraphViewGetFanin(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
const GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
const NodeDef* node = graph_view.GetNode("node");
graph_view.GetFanin({node, 0});
}
testing::StopTiming();
}
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanin);
static void BM_GraphViewGetRegularFanin(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
void BM_GraphViewGetRegularFanin(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
const GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
const NodeDef* node = graph_view.GetNode("node");
graph_view.GetRegularFanin({node, 0});
}
testing::StopTiming();
}
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanin);
static void BM_GraphViewGetFanouts(int iters, int num_fanins, int num_fanouts) {
testing::StopTiming();
void BM_GraphViewGetFanouts(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
const GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
const NodeDef* node = graph_view.GetNode("node");
graph_view.GetFanouts(*node, /*include_controlled_nodes=*/false);
}
testing::StopTiming();
}
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanouts);
static void BM_GraphViewGetFanins(int iters, int num_fanins, int num_fanouts) {
testing::StopTiming();
void BM_GraphViewGetFanins(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
const GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
const NodeDef* node = graph_view.GetNode("node");
graph_view.GetFanins(*node, /*include_controlling_nodes=*/false);
}
testing::StopTiming();
}
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanins);
static void BM_GraphViewGetFanoutEdges(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
void BM_GraphViewGetFanoutEdges(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
const GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
const NodeDef* node = graph_view.GetNode("node");
graph_view.GetFanoutEdges(*node, /*include_controlled_edges=*/false);
}
testing::StopTiming();
}
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFanoutEdges);
static void BM_GraphViewGetFaninEdges(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
void BM_GraphViewGetFaninEdges(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
const GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
GraphView graph_view(&graph_def);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
const NodeDef* node = graph_view.GetNode("node");
graph_view.GetFaninEdges(*node, /*include_controlling_edges=*/false);
}
testing::StopTiming();
}
RUN_FANIN_FANOUT_BENCHMARK(BM_GraphViewGetFaninEdges);

View File

@ -2391,47 +2391,42 @@ TEST_F(TopologicalSortTest, PushVisitedNodes) {
->ArgPair(100000, 16);
template <typename GraphViewT>
static void BM_GraphViewTConstruction(int iters, int num_nodes,
int num_edges_per_node) {
testing::StopTiming();
void BM_GraphViewTConstruction(::testing::benchmark::State& state) {
const int num_nodes = state.range(0);
const int num_edges_per_node = state.range(1);
GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto i : state) {
Status s;
GraphViewT graph_view(&graph_def, &s);
}
testing::StopTiming();
}
static void BM_GraphViewConstruction(int iters, int num_nodes,
int num_edges_per_node) {
BM_GraphViewTConstruction<GraphView>(iters, num_nodes, num_edges_per_node);
void BM_GraphViewConstruction(::testing::benchmark::State& state) {
BM_GraphViewTConstruction<GraphView>(state);
}
static void BM_MutableGraphViewConstruction(int iters, int num_nodes,
int num_edges_per_node) {
BM_GraphViewTConstruction<MutableGraphView>(iters, num_nodes,
num_edges_per_node);
void BM_MutableGraphViewConstruction(::testing::benchmark::State& state) {
BM_GraphViewTConstruction<MutableGraphView>(state);
}
static void BM_MutableGraphViewClearAttrs(int iters, int num_nodes,
int num_edges_per_node) {
testing::StopTiming();
void BM_MutableGraphViewClearAttrs(::testing::benchmark::State& state) {
const int num_nodes = state.range(0);
const int num_edges_per_node = state.range(1);
GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node);
Status s;
MutableGraphView graph_view(&graph_def, &s);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto i : state) {
utils::Mutation* mutation = graph_view.GetMutationBuilder();
for (int j = 0; j < num_nodes; ++j) {
mutation->RemoveNodeAttr(graph_view.GetNode(j), "_some_random_attr");
}
s = mutation->Apply();
}
testing::StopTiming();
}
RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_GraphViewConstruction);
@ -2449,58 +2444,54 @@ RUN_NUM_NODE_NUM_EDGE_BENCHMARK(BM_MutableGraphViewClearAttrs);
->Arg(100000);
template <typename GraphViewT>
static void BM_GraphViewTConstructionWithControlDependencies(
int iters, int num_fanins_fanouts) {
testing::StopTiming();
void BM_GraphViewTConstructionWithControlDependencies(
::testing::benchmark::State& state) {
const int num_fanins_fanouts = state.range(0);
GraphDef graph_def =
test::CreateFaninFanoutNodeGraph(num_fanins_fanouts, num_fanins_fanouts,
num_fanins_fanouts, num_fanins_fanouts,
/*fanout_unique_index=*/true);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto i : state) {
Status s;
GraphViewT graph_view(&graph_def, &s);
}
testing::StopTiming();
}
static void BM_GraphViewConstructionWithControlDependencies(
int iters, int num_fanins_fanouts) {
BM_GraphViewTConstructionWithControlDependencies<GraphView>(
iters, num_fanins_fanouts);
void BM_GraphViewConstructionWithControlDependencies(
::testing::benchmark::State& state) {
BM_GraphViewTConstructionWithControlDependencies<GraphView>(state);
}
static void BM_MutableGraphViewConstructionWithControlDependencies(
int iters, int num_fanins_fanouts) {
BM_GraphViewTConstructionWithControlDependencies<MutableGraphView>(
iters, num_fanins_fanouts);
void BM_MutableGraphViewConstructionWithControlDependencies(
::testing::benchmark::State& state) {
BM_GraphViewTConstructionWithControlDependencies<MutableGraphView>(state);
}
RUN_NUM_NODE_BENCHMARK(BM_GraphViewConstructionWithControlDependencies);
RUN_NUM_NODE_BENCHMARK(BM_MutableGraphViewConstructionWithControlDependencies);
template <typename GraphViewT>
static void BM_GraphViewTGetNode(int iters, int num_nodes) {
testing::StopTiming();
void BM_GraphViewTGetNode(::testing::benchmark::State& state) {
const int num_nodes = state.range(0);
GraphDef graph_def =
test::CreateGraphDef(num_nodes, /*num_edges_per_node=*/16);
Status s;
GraphViewT graph_view(&graph_def, &s);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto i : state) {
graph_view.GetNode("out");
}
testing::StopTiming();
}
static void BM_GraphViewGetNode(int iters, int num_nodes) {
BM_GraphViewTGetNode<GraphView>(iters, num_nodes);
void BM_GraphViewGetNode(::testing::benchmark::State& state) {
BM_GraphViewTGetNode<GraphView>(state);
}
static void BM_MutableGraphViewGetNode(int iters, int num_nodes) {
BM_GraphViewTGetNode<MutableGraphView>(iters, num_nodes);
void BM_MutableGraphViewGetNode(::testing::benchmark::State& state) {
BM_GraphViewTGetNode<MutableGraphView>(state);
}
RUN_NUM_NODE_BENCHMARK(BM_GraphViewGetNode);
@ -2535,201 +2526,180 @@ RUN_NUM_NODE_BENCHMARK(BM_MutableGraphViewGetNode);
->ArgPair(100000, 100000);
template <typename GraphViewT>
static void BM_GraphViewTGetRegularFanin(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
void BM_GraphViewTGetRegularFanin(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
Status s;
GraphViewT graph_view(&graph_def, &s);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto i : state) {
auto* node = graph_view.GetNode("node");
node->GetRegularFanin(0);
}
testing::StopTiming();
}
static void BM_GraphViewGetRegularFanin(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTGetRegularFanin<GraphView>(iters, num_fanins, num_fanouts);
void BM_GraphViewGetRegularFanin(::testing::benchmark::State& state) {
BM_GraphViewTGetRegularFanin<GraphView>(state);
}
static void BM_MutableGraphViewGetRegularFanin(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTGetRegularFanin<MutableGraphView>(iters, num_fanins,
num_fanouts);
void BM_MutableGraphViewGetRegularFanin(::testing::benchmark::State& state) {
BM_GraphViewTGetRegularFanin<MutableGraphView>(state);
}
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanin);
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanin);
template <typename GraphViewT>
static void BM_GraphViewTGetRegularFanout(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
void BM_GraphViewTGetRegularFanout(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
Status s;
GraphViewT graph_view(&graph_def, &s);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto i : state) {
auto* node = graph_view.GetNode("node");
node->GetRegularFanout(0);
}
testing::StopTiming();
}
static void BM_GraphViewGetRegularFanout(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTGetRegularFanout<GraphView>(iters, num_fanins, num_fanouts);
void BM_GraphViewGetRegularFanout(::testing::benchmark::State& state) {
BM_GraphViewTGetRegularFanout<GraphView>(state);
}
static void BM_MutableGraphViewGetRegularFanout(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTGetRegularFanout<MutableGraphView>(iters, num_fanins,
num_fanouts);
void BM_MutableGraphViewGetRegularFanout(::testing::benchmark::State& state) {
BM_GraphViewTGetRegularFanout<MutableGraphView>(state);
}
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanout);
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanout);
template <typename GraphViewT>
static void BM_GraphViewTGetRegularFanins(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
void BM_GraphViewTGetRegularFanins(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
Status s;
GraphViewT graph_view(&graph_def, &s);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto i : state) {
auto* node = graph_view.GetNode("node");
node->GetRegularFanins();
}
testing::StopTiming();
}
static void BM_GraphViewGetRegularFanins(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTGetRegularFanins<GraphView>(iters, num_fanins, num_fanouts);
void BM_GraphViewGetRegularFanins(::testing::benchmark::State& state) {
BM_GraphViewTGetRegularFanins<GraphView>(state);
}
static void BM_MutableGraphViewGetRegularFanins(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTGetRegularFanins<MutableGraphView>(iters, num_fanins,
num_fanouts);
void BM_MutableGraphViewGetRegularFanins(::testing::benchmark::State& state) {
BM_GraphViewTGetRegularFanins<MutableGraphView>(state);
}
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanins);
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanins);
template <typename GraphViewT>
static void BM_GraphViewTGetRegularFanouts(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
void BM_GraphViewTGetRegularFanouts(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
Status s;
GraphViewT graph_view(&graph_def, &s);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto i : state) {
auto* node = graph_view.GetNode("node");
node->GetRegularFanouts();
}
testing::StopTiming();
}
static void BM_GraphViewGetRegularFanouts(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTGetRegularFanouts<GraphView>(iters, num_fanins, num_fanouts);
void BM_GraphViewGetRegularFanouts(::testing::benchmark::State& state) {
BM_GraphViewTGetRegularFanouts<GraphView>(state);
}
static void BM_MutableGraphViewGetRegularFanouts(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTGetRegularFanouts<MutableGraphView>(iters, num_fanins,
num_fanouts);
void BM_MutableGraphViewGetRegularFanouts(::testing::benchmark::State& state) {
BM_GraphViewTGetRegularFanouts<MutableGraphView>(state);
}
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetRegularFanouts);
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetRegularFanouts);
template <typename GraphViewT>
static void BM_GraphViewTGetControllingFanins(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
void BM_GraphViewTGetControllingFanins(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
Status s;
GraphViewT graph_view(&graph_def, &s);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto i : state) {
auto* node = graph_view.GetNode("node");
node->GetControllingFanins();
}
testing::StopTiming();
}
static void BM_GraphViewGetControllingFanins(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTGetControllingFanins<GraphView>(iters, num_fanins, num_fanouts);
void BM_GraphViewGetControllingFanins(::testing::benchmark::State& state) {
BM_GraphViewTGetControllingFanins<GraphView>(state);
}
static void BM_MutableGraphViewGetControllingFanins(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTGetControllingFanins<MutableGraphView>(iters, num_fanins,
num_fanouts);
void BM_MutableGraphViewGetControllingFanins(
::testing::benchmark::State& state) {
BM_GraphViewTGetControllingFanins<MutableGraphView>(state);
}
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetControllingFanins);
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetControllingFanins);
template <typename GraphViewT>
static void BM_GraphViewTGetControlledFanouts(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
void BM_GraphViewTGetControlledFanouts(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
Status s;
GraphViewT graph_view(&graph_def, &s);
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto i : state) {
auto* node = graph_view.GetNode("node");
node->GetControlledFanouts();
}
testing::StopTiming();
}
static void BM_GraphViewGetControlledFanouts(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTGetControlledFanouts<GraphView>(iters, num_fanins, num_fanouts);
void BM_GraphViewGetControlledFanouts(::testing::benchmark::State& state) {
BM_GraphViewTGetControlledFanouts<GraphView>(state);
}
static void BM_MutableGraphViewGetControlledFanouts(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTGetControlledFanouts<MutableGraphView>(iters, num_fanins,
num_fanouts);
void BM_MutableGraphViewGetControlledFanouts(
::testing::benchmark::State& state) {
BM_GraphViewTGetControlledFanouts<MutableGraphView>(state);
}
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewGetControlledFanouts);
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewGetControlledFanouts);
template <typename GraphViewT, bool IsLast>
inline static void BM_GraphViewTHasRegularFanin(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
inline void BM_GraphViewTHasRegularFanin(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, /*num_controlling_fanins=*/0,
/*num_controlled_fanouts=*/0, /*fanout_unique_index=*/false);
@ -2739,34 +2709,27 @@ inline static void BM_GraphViewTHasRegularFanin(int iters, int num_fanins,
auto* node = graph_view.GetNode(absl::StrFormat("out%05d", index));
auto* fanin = graph_view.GetNode("node");
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto i : state) {
node->HasFanin({&graph_view, fanin->node_index(), 0});
}
testing::StopTiming();
}
static void BM_GraphViewHasRegularFaninFirst(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTHasRegularFanin<GraphView, false>(iters, num_fanins,
num_fanouts);
void BM_GraphViewHasRegularFaninFirst(::testing::benchmark::State& state) {
BM_GraphViewTHasRegularFanin<GraphView, false>(state);
}
static void BM_GraphViewHasRegularFaninLast(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTHasRegularFanin<GraphView, true>(iters, num_fanins, num_fanouts);
void BM_GraphViewHasRegularFaninLast(::testing::benchmark::State& state) {
BM_GraphViewTHasRegularFanin<GraphView, true>(state);
}
static void BM_MutableGraphViewHasRegularFaninFirst(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTHasRegularFanin<MutableGraphView, false>(iters, num_fanins,
num_fanouts);
void BM_MutableGraphViewHasRegularFaninFirst(
::testing::benchmark::State& state) {
BM_GraphViewTHasRegularFanin<MutableGraphView, false>(state);
}
static void BM_MutableGraphViewHasRegularFaninLast(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTHasRegularFanin<MutableGraphView, true>(iters, num_fanins,
num_fanouts);
void BM_MutableGraphViewHasRegularFaninLast(
::testing::benchmark::State& state) {
BM_GraphViewTHasRegularFanin<MutableGraphView, true>(state);
}
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasRegularFaninFirst);
@ -2775,9 +2738,11 @@ RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFaninFirst);
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFaninLast);
template <typename GraphViewT, bool IsLast>
inline static void BM_GraphViewTHasControllingFanin(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
inline void BM_GraphViewTHasControllingFanin(
::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/true);
@ -2787,37 +2752,27 @@ inline static void BM_GraphViewTHasControllingFanin(int iters, int num_fanins,
auto* node = graph_view.GetNode(absl::StrFormat("control_out%05d", index));
auto* fanin = graph_view.GetNode("node");
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto i : state) {
node->HasFanin({&graph_view, fanin->node_index(), Graph::kControlSlot});
}
testing::StopTiming();
}
static void BM_GraphViewHasControllingFaninFirst(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTHasControllingFanin<GraphView, false>(iters, num_fanins,
num_fanouts);
void BM_GraphViewHasControllingFaninFirst(::testing::benchmark::State& state) {
BM_GraphViewTHasControllingFanin<GraphView, false>(state);
}
static void BM_GraphViewHasControllingFaninLast(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTHasControllingFanin<GraphView, true>(iters, num_fanins,
num_fanouts);
void BM_GraphViewHasControllingFaninLast(::testing::benchmark::State& state) {
BM_GraphViewTHasControllingFanin<GraphView, true>(state);
}
static void BM_MutableGraphViewHasControllingFaninFirst(int iters,
int num_fanins,
int num_fanouts) {
BM_GraphViewTHasControllingFanin<MutableGraphView, false>(iters, num_fanins,
num_fanouts);
void BM_MutableGraphViewHasControllingFaninFirst(
::testing::benchmark::State& state) {
BM_GraphViewTHasControllingFanin<MutableGraphView, false>(state);
}
static void BM_MutableGraphViewHasControllingFaninLast(int iters,
int num_fanins,
int num_fanouts) {
BM_GraphViewTHasControllingFanin<MutableGraphView, true>(iters, num_fanins,
num_fanouts);
void BM_MutableGraphViewHasControllingFaninLast(
::testing::benchmark::State& state) {
BM_GraphViewTHasControllingFanin<MutableGraphView, true>(state);
}
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControllingFaninFirst);
@ -2826,9 +2781,10 @@ RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControllingFaninFirst);
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControllingFaninLast);
template <typename GraphViewT, bool IsLast>
inline static void BM_GraphViewTHasRegularFanout(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
inline void BM_GraphViewTHasRegularFanout(::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, /*num_controlling_fanins=*/0,
/*num_controlled_fanouts=*/0, /*fanout_unique_index=*/false);
@ -2838,35 +2794,27 @@ inline static void BM_GraphViewTHasRegularFanout(int iters, int num_fanins,
auto* node = graph_view.GetNode(absl::StrFormat("in%05d", index));
auto* fanout = graph_view.GetNode("node");
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto i : state) {
node->HasFanout({&graph_view, fanout->node_index(), index});
}
testing::StopTiming();
}
static void BM_GraphViewHasRegularFanoutFirst(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTHasRegularFanout<GraphView, false>(iters, num_fanins,
num_fanouts);
void BM_GraphViewHasRegularFanoutFirst(::testing::benchmark::State& state) {
BM_GraphViewTHasRegularFanout<GraphView, false>(state);
}
static void BM_GraphViewHasRegularFanoutLast(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTHasRegularFanout<GraphView, true>(iters, num_fanins,
num_fanouts);
void BM_GraphViewHasRegularFanoutLast(::testing::benchmark::State& state) {
BM_GraphViewTHasRegularFanout<GraphView, true>(state);
}
static void BM_MutableGraphViewHasRegularFanoutFirst(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTHasRegularFanout<MutableGraphView, false>(iters, num_fanins,
num_fanouts);
void BM_MutableGraphViewHasRegularFanoutFirst(
::testing::benchmark::State& state) {
BM_GraphViewTHasRegularFanout<MutableGraphView, false>(state);
}
static void BM_MutableGraphViewHasRegularFanoutLast(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTHasRegularFanout<MutableGraphView, true>(iters, num_fanins,
num_fanouts);
void BM_MutableGraphViewHasRegularFanoutLast(
::testing::benchmark::State& state) {
BM_GraphViewTHasRegularFanout<MutableGraphView, true>(state);
}
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasRegularFanoutFirst);
@ -2875,9 +2823,11 @@ RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFanoutFirst);
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasRegularFanoutLast);
template <typename GraphViewT, bool IsLast>
inline static void BM_GraphViewTHasControlledFanout(int iters, int num_fanins,
int num_fanouts) {
testing::StopTiming();
inline void BM_GraphViewTHasControlledFanout(
::testing::benchmark::State& state) {
const int num_fanins = state.range(0);
const int num_fanouts = state.range(1);
GraphDef graph_def = test::CreateFaninFanoutNodeGraph(
num_fanins, num_fanouts, num_fanins, num_fanouts,
/*fanout_unique_index=*/false);
@ -2887,37 +2837,27 @@ inline static void BM_GraphViewTHasControlledFanout(int iters, int num_fanins,
auto* node = graph_view.GetNode(absl::StrFormat("control_in%05d", index));
auto* fanout = graph_view.GetNode("node");
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto i : state) {
node->HasFanout({&graph_view, fanout->node_index(), Graph::kControlSlot});
}
testing::StopTiming();
}
static void BM_GraphViewHasControlledFanoutFirst(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTHasControlledFanout<GraphView, false>(iters, num_fanins,
num_fanouts);
void BM_GraphViewHasControlledFanoutFirst(::testing::benchmark::State& state) {
BM_GraphViewTHasControlledFanout<GraphView, false>(state);
}
static void BM_GraphViewHasControlledFanoutLast(int iters, int num_fanins,
int num_fanouts) {
BM_GraphViewTHasControlledFanout<GraphView, true>(iters, num_fanins,
num_fanouts);
void BM_GraphViewHasControlledFanoutLast(::testing::benchmark::State& state) {
BM_GraphViewTHasControlledFanout<GraphView, true>(state);
}
static void BM_MutableGraphViewHasControlledFanoutFirst(int iters,
int num_fanins,
int num_fanouts) {
BM_GraphViewTHasControlledFanout<MutableGraphView, false>(iters, num_fanins,
num_fanouts);
void BM_MutableGraphViewHasControlledFanoutFirst(
::testing::benchmark::State& state) {
BM_GraphViewTHasControlledFanout<MutableGraphView, false>(state);
}
static void BM_MutableGraphViewHasControlledFanoutLast(int iters,
int num_fanins,
int num_fanouts) {
BM_GraphViewTHasControlledFanout<MutableGraphView, true>(iters, num_fanins,
num_fanouts);
void BM_MutableGraphViewHasControlledFanoutLast(
::testing::benchmark::State& state) {
BM_GraphViewTHasControlledFanout<MutableGraphView, true>(state);
}
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControlledFanoutFirst);
@ -2925,19 +2865,17 @@ RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_GraphViewHasControlledFanoutLast);
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControlledFanoutFirst);
RUN_NUM_FANIN_NUM_FANOUT_BENCHMARK(BM_MutableGraphViewHasControlledFanoutLast);
static void BM_SortTopologically(int iters, int size) {
testing::StopTiming();
void BM_SortTopologically(::testing::benchmark::State& state) {
const int size = state.range(0);
GraphDef graph = test::CreateRandomGraph(size);
Status status;
MutableGraphView graph_view(&graph, &status);
TF_ASSERT_OK(status);
testing::StartTiming();
for (int i = 0; i < iters; i++) {
for (auto i : state) {
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
}
testing::StopTiming();
}
RUN_NUM_NODE_BENCHMARK(BM_SortTopologically);

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/benchmark_testlib.h"
@ -196,19 +197,17 @@ TEST_F(TopologicalSortTest, ExtraDependencies) {
ComputeTopologicalOrder(graph, extra_dependencies, &topo_order).ok());
}
static void BM_ComputeTopologicalOrder(int iters, int size) {
testing::StopTiming();
static void BM_ComputeTopologicalOrder(::testing::benchmark::State& state) {
const int size = state.range(0);
GraphDef graph = test::CreateRandomGraph(size);
testing::StartTiming();
std::vector<const NodeDef*> topo_order;
for (int i = 0; i < iters; i++) {
for (auto s : state) {
topo_order.clear();
Status st = ComputeTopologicalOrder(graph, &topo_order);
CHECK(st.ok()) << "Failed to compute topological order";
}
testing::StopTiming();
}
BENCHMARK(BM_ComputeTopologicalOrder)
->Arg(10)

View File

@ -470,15 +470,16 @@ TEST(IsKernelRegisteredForNode, All) {
EXPECT_FALSE(IsKernelRegisteredForNode(node).ok());
}
#define BM_NodePositionIfSameNode(I, N, NAME) \
static void BM_NodePositionIfSameNode_##NAME(int iters) { \
string input = I; \
string node = N; \
for (int i = 0; i < iters; ++i) { \
const int pos = NodePositionIfSameNode(input, node); \
CHECK_GT(pos, -3); \
} \
} \
#define BM_NodePositionIfSameNode(I, N, NAME) \
static void BM_NodePositionIfSameNode_##NAME( \
::testing::benchmark::State& state) { \
string input = I; \
string node = N; \
for (auto s : state) { \
const int pos = NodePositionIfSameNode(input, node); \
CHECK_GT(pos, -3); \
} \
} \
BENCHMARK(BM_NodePositionIfSameNode_##NAME)
BM_NodePositionIfSameNode("foo/bar/baz:7", "foo/bar/baz", Match_7);
@ -487,10 +488,12 @@ BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl);
BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0);
BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end);
static void BM_NodeNameAsStringPiece(int iters, int size) {
void BM_NodeNameAsStringPiece(::testing::benchmark::State& state) {
const int size = state.range(0);
string input(size + 3, 'x');
input[size] = ':';
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
StringPiece node_name = NodeNameAsStringPiece(input);
CHECK_GT(node_name.size(), 0);
}
@ -498,9 +501,10 @@ static void BM_NodeNameAsStringPiece(int iters, int size) {
BENCHMARK(BM_NodeNameAsStringPiece)->Range(1, 1024);
#define BM_ParseNodeNameAsStringPiece(I, NAME) \
static void BM_ParseNodeNameAsStringPiece_##NAME(int iters) { \
static void BM_ParseNodeNameAsStringPiece_##NAME( \
::testing::benchmark::State& state) { \
string input = I; \
for (int i = 0; i < iters; ++i) { \
for (auto s : state) { \
int position; \
const StringPiece name = ParseNodeNameAsStringPiece(input, &position); \
CHECK_GE(position, -1); \
@ -683,25 +687,23 @@ TEST(SetTensorValueTest, Quantized) {
/*error_msg=*/"");
}
static void BM_NodeMapConstruct(int iters, int size) {
testing::StopTiming();
void BM_NodeMapConstruct(::testing::benchmark::State& state) {
const int size = state.range(0);
GraphDef graph = test::CreateRandomGraph(size);
testing::StartTiming();
for (int i = 0; i < iters; i++) {
for (auto s : state) {
NodeMap node_map(&graph);
}
testing::StopTiming();
}
BENCHMARK(BM_NodeMapConstruct)->Range(1, 1 << 20);
static void BM_ImmutableNodeMapConstruct(int iters, int size) {
testing::StopTiming();
void BM_ImmutableNodeMapConstruct(::testing::benchmark::State& state) {
const int size = state.range(0);
GraphDef graph = test::CreateRandomGraph(size);
testing::StartTiming();
for (int i = 0; i < iters; i++) {
for (auto s : state) {
ImmutableNodeMap node_map(&graph);
}
testing::StopTiming();
}
BENCHMARK(BM_ImmutableNodeMapConstruct)->Range(1, 1 << 20);