Update benchmarks to use newer API

PiperOrigin-RevId: 344103956
Change-Id: Ie7a0a02fbadc7cc7dc74e07d4202f241a6ead1f6
This commit is contained in:
A. Unique TensorFlower 2020-11-24 11:55:05 -08:00 committed by TensorFlower Gardener
parent ace0c15a22
commit 38ab98e198

View File

@ -511,14 +511,16 @@ TEST_F(GraphCyclesTest, CanContractEdge) {
EXPECT_TRUE(g_.CanContractEdge(3, 4));
}
static void BM_StressTest(int iters, int num_nodes) {
while (iters > 0) {
static void BM_StressTest(::testing::benchmark::State &state) {
const int num_nodes = state.range(0);
for (auto s : state) {
tensorflow::GraphCycles g;
int32 *nodes = new int32[num_nodes];
for (int i = 0; i < num_nodes; i++) {
nodes[i] = g.NewNode();
}
for (int i = 0; i < num_nodes && iters > 0; i++, iters--) {
for (int i = 0; i < num_nodes; i++) {
int end = std::min(num_nodes, i + 5);
for (int j = i + 1; j < end; j++) {
if (nodes[i] >= 0 && nodes[j] >= 0) {
@ -531,9 +533,11 @@ static void BM_StressTest(int iters, int num_nodes) {
}
BENCHMARK(BM_StressTest)->Range(2048, 1048576);
static void BM_ContractEdge(int iters, int num_nodes) {
while (iters-- > 0) {
tensorflow::testing::StopTiming();
static void BM_ContractEdge(::testing::benchmark::State &state) {
const int num_nodes = state.range(0);
for (auto s : state) {
state.PauseTiming();
tensorflow::GraphCycles g;
std::vector<int32> nodes;
nodes.reserve(num_nodes);
@ -545,7 +549,7 @@ static void BM_ContractEdge(int iters, int num_nodes) {
g.InsertEdge(nodes[i], nodes[num_nodes - 1]);
}
tensorflow::testing::StartTiming();
state.ResumeTiming();
int node = num_nodes - 1;
for (int i = 0; i < num_nodes - 1; ++i) {
node = g.ContractEdge(nodes[i], node).value();