diff --git a/tensorflow/core/graph/algorithm_test.cc b/tensorflow/core/graph/algorithm_test.cc index faefb0b82e9..7fc4abb5492 100644 --- a/tensorflow/core/graph/algorithm_test.cc +++ b/tensorflow/core/graph/algorithm_test.cc @@ -203,21 +203,21 @@ TEST(AlgorithmTest, PostOrderWithEdgeFilter) { } } -static void BM_PruneForReverseReachability(int iters, int num_nodes, - int num_edges_per_node) { - testing::StopTiming(); +void BM_PruneForReverseReachability(::testing::benchmark::State& state) { + const int num_nodes = state.range(0); + const int num_edges_per_node = state.range(1); const GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node); const auto registry = OpRegistry::Global(); GraphConstructorOptions opts; - for (int i = 0; i < iters; ++i) { + for (auto s : state) { + state.PauseTiming(); Graph graph(registry); TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); std::unordered_set visited; visited.insert(graph.FindNodeId(graph.num_nodes() - 1)); - testing::StartTiming(); + state.ResumeTiming(); PruneForReverseReachability(&graph, std::move(visited)); - testing::StopTiming(); } } BENCHMARK(BM_PruneForReverseReachability)->ArgPair(10, 2); diff --git a/tensorflow/core/graph/graph_test.cc b/tensorflow/core/graph/graph_test.cc index a8b421367ab..2801bd7c961 100644 --- a/tensorflow/core/graph/graph_test.cc +++ b/tensorflow/core/graph/graph_test.cc @@ -661,9 +661,9 @@ TEST_F(GraphTest, BuildNodeNameIndex) { } } -static void BM_InEdgeIteration(int iters, int num_nodes, - int num_edges_per_node) { - testing::StopTiming(); +void BM_InEdgeIteration(::testing::benchmark::State& state) { + const int num_nodes = state.range(0); + const int num_edges_per_node = state.range(1); const GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node); Graph graph(OpRegistry::Global()); @@ -671,8 +671,7 @@ static void BM_InEdgeIteration(int iters, int num_nodes, TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); int64 sum = 0; - testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { for (const Node* node : graph.nodes()) { for (auto e : node->in_edges()) { sum += e->id(); @@ -680,7 +679,6 @@ static void BM_InEdgeIteration(int iters, int num_nodes, } } VLOG(1) << sum; - testing::StopTiming(); } BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 2); BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 2); @@ -703,8 +701,9 @@ BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 16); BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 16); BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 16); -static void BM_GraphCreation(int iters, int num_nodes, int num_edges_per_node) { - testing::StopTiming(); +void BM_GraphCreation(::testing::benchmark::State& state) { + const int num_nodes = state.range(0); + const int num_edges_per_node = state.range(1); const GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node); const auto registry = OpRegistry::Global(); @@ -713,14 +712,12 @@ static void BM_GraphCreation(int iters, int num_nodes, int num_edges_per_node) { Graph graph(registry); TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); int64 sum = 0; - testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { Graph graph(registry); TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); sum += graph.num_node_ids(); } VLOG(1) << sum; - testing::StopTiming(); } BENCHMARK(BM_GraphCreation)->ArgPair(10, 2); BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 2); @@ -743,8 +740,9 @@ BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 16); BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 16); BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 16); -static void BM_ToGraphDef(int iters, int num_nodes, int num_edges_per_node) { - testing::StopTiming(); +void BM_ToGraphDef(::testing::benchmark::State& state) { + const int num_nodes = state.range(0); + const int num_edges_per_node = state.range(1); const GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node); const auto registry = OpRegistry::Global(); @@ -753,14 +751,12 @@ static void BM_ToGraphDef(int iters, int num_nodes, int num_edges_per_node) { Graph graph(registry); TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); int64 sum = 0; - testing::StartTiming(); - for (int i = 0; i < iters; ++i) { + for (auto s : state) { GraphDef graph_def; graph.ToGraphDef(&graph_def); sum += graph_def.node_size(); } VLOG(1) << sum; - testing::StopTiming(); } BENCHMARK(BM_ToGraphDef)->ArgPair(10, 2); BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 2); @@ -783,20 +779,20 @@ BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 16); BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16); BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16); -static void BM_RemoveNode(int iters, int num_nodes, int num_edges_per_node) { - testing::StopTiming(); +void BM_RemoveNode(::testing::benchmark::State& state) { + const int num_nodes = state.range(0); + const int num_edges_per_node = state.range(1); const GraphDef graph_def = test::CreateGraphDef(num_nodes, num_edges_per_node); const auto registry = OpRegistry::Global(); GraphConstructorOptions opts; - for (int i = 0; i < iters; ++i) { + for (auto s : state) { Graph graph(registry); TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); testing::StartTiming(); for (Node* n : graph.op_nodes()) { graph.RemoveNode(n); } - testing::StopTiming(); } } BENCHMARK(BM_RemoveNode)->ArgPair(10, 2); diff --git a/tensorflow/core/graph/optimizer_cse_test.cc b/tensorflow/core/graph/optimizer_cse_test.cc index 08292068efc..be10a9daa9f 100644 --- a/tensorflow/core/graph/optimizer_cse_test.cc +++ b/tensorflow/core/graph/optimizer_cse_test.cc @@ -347,8 +347,8 @@ TEST_F(OptimizerCSETest, Constant_Dedup) { EXPECT_EQ(node_set.count("n/_3(Const)") + node_set.count("n/_4(Const)"), 1); } -static void BM_CSE(int iters, int op_nodes) { - testing::StopTiming(); +void BM_CSE(::testing::benchmark::State& state) { + const int op_nodes = state.range(0); string s; for (int in = 0; in < 10; in++) { s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in); @@ -363,7 +363,8 @@ static void BM_CSE(int iters, int op_nodes) { } bool first = true; - while (iters > 0) { + for (auto i : state) { + state.PauseTiming(); Graph* graph = new Graph(OpRegistry::Global()); InitGraph(s, graph); int N = graph->num_node_ids(); @@ -372,13 +373,12 @@ static void BM_CSE(int iters, int op_nodes) { first = false; } { - testing::StartTiming(); + state.ResumeTiming(); OptimizeCSE(graph, nullptr); - testing::StopTiming(); + state.PauseTiming(); } - iters -= N; // Our benchmark units are individual graph nodes, - // not whole graphs delete graph; + state.ResumeTiming(); } } BENCHMARK(BM_CSE)->Arg(1000)->Arg(10000); diff --git a/tensorflow/core/graph/subgraph_test.cc b/tensorflow/core/graph/subgraph_test.cc index a8a834a0a83..571da3b62e5 100644 --- a/tensorflow/core/graph/subgraph_test.cc +++ b/tensorflow/core/graph/subgraph_test.cc @@ -342,14 +342,14 @@ TEST_F(SubgraphTest, Errors) { REGISTER_OP("In").Output("o: float"); REGISTER_OP("Op").Input("i: float").Output("o: float"); -static void BM_SubgraphHelper(int iters, int num_nodes, - bool use_function_convention) { +void BM_SubgraphHelper(::testing::benchmark::State& state, + bool use_function_convention) { + const int num_nodes = state.range(0); DeviceAttributes device_info; device_info.set_name("/job:a/replica:0/task:0/cpu:0"); device_info.set_device_type(DeviceType(DEVICE_CPU).type()); device_info.set_incarnation(0); - testing::StopTiming(); Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. GraphDefBuilder b(GraphDefBuilder::kFailImmediately); @@ -371,8 +371,8 @@ static void BM_SubgraphHelper(int iters, int num_nodes, } std::vector fetch; std::vector targets = {strings::StrCat("N", num_nodes - 1)}; - testing::StartTiming(); - while (--iters > 0) { + + for (auto s : state) { Graph* subgraph = new Graph(OpRegistry::Global()); CopyGraph(g, subgraph); subgraph::RewriteGraphMetadata metadata; @@ -383,11 +383,11 @@ static void BM_SubgraphHelper(int iters, int num_nodes, } } -static void BM_Subgraph(int iters, int num_nodes) { - BM_SubgraphHelper(iters, num_nodes, false /* use_function_convention */); +void BM_Subgraph(::testing::benchmark::State& state) { + BM_SubgraphHelper(state, false /* use_function_convention */); } -static void BM_SubgraphFunctionConvention(int iters, int num_nodes) { - BM_SubgraphHelper(iters, num_nodes, true /* use_function_convention */); +void BM_SubgraphFunctionConvention(::testing::benchmark::State& state) { + BM_SubgraphHelper(state, true /* use_function_convention */); } BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000); BENCHMARK(BM_SubgraphFunctionConvention) diff --git a/tensorflow/core/graph/tensor_id_test.cc b/tensorflow/core/graph/tensor_id_test.cc index 878afbe7d65..1b9247d3e48 100644 --- a/tensorflow/core/graph/tensor_id_test.cc +++ b/tensorflow/core/graph/tensor_id_test.cc @@ -39,8 +39,8 @@ uint32 Skewed(random::SimplePhilox* rnd, int max_log) { return rnd->Rand32() % space; } -void BM_ParseTensorName(int iters, int arg) { - testing::StopTiming(); +void BM_ParseTensorName(::testing::benchmark::State& state) { + const int arg = state.range(0); random::PhiloxRandom philox(301, 17); random::SimplePhilox rnd(&philox); std::vector names; @@ -78,11 +78,11 @@ void BM_ParseTensorName(int iters, int arg) { } names.push_back(name); } - testing::StartTiming(); + TensorId id; int index = 0; int sum = 0; - while (--iters > 0) { + for (auto s : state) { id = ParseTensorName(names[index++ % names.size()]); sum += id.second; }