Internal tests cleanup

PiperOrigin-RevId: 339888534
Change-Id: I60fb13fad42ec5938d7276151afdce2dc3a6926f
This commit is contained in:
A. Unique TensorFlower 2020-10-30 09:49:55 -07:00 committed by TensorFlower Gardener
parent 5cd75eedd5
commit 556c1a4a58
5 changed files with 42 additions and 46 deletions

View File

@ -203,21 +203,21 @@ TEST(AlgorithmTest, PostOrderWithEdgeFilter) {
}
}
static void BM_PruneForReverseReachability(int iters, int num_nodes,
int num_edges_per_node) {
testing::StopTiming();
void BM_PruneForReverseReachability(::testing::benchmark::State& state) {
const int num_nodes = state.range(0);
const int num_edges_per_node = state.range(1);
const GraphDef graph_def =
test::CreateGraphDef(num_nodes, num_edges_per_node);
const auto registry = OpRegistry::Global();
GraphConstructorOptions opts;
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
state.PauseTiming();
Graph graph(registry);
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
std::unordered_set<const Node*> visited;
visited.insert(graph.FindNodeId(graph.num_nodes() - 1));
testing::StartTiming();
state.ResumeTiming();
PruneForReverseReachability(&graph, std::move(visited));
testing::StopTiming();
}
}
BENCHMARK(BM_PruneForReverseReachability)->ArgPair(10, 2);

View File

@ -661,9 +661,9 @@ TEST_F(GraphTest, BuildNodeNameIndex) {
}
}
static void BM_InEdgeIteration(int iters, int num_nodes,
int num_edges_per_node) {
testing::StopTiming();
void BM_InEdgeIteration(::testing::benchmark::State& state) {
const int num_nodes = state.range(0);
const int num_edges_per_node = state.range(1);
const GraphDef graph_def =
test::CreateGraphDef(num_nodes, num_edges_per_node);
Graph graph(OpRegistry::Global());
@ -671,8 +671,7 @@ static void BM_InEdgeIteration(int iters, int num_nodes,
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
int64 sum = 0;
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
for (const Node* node : graph.nodes()) {
for (auto e : node->in_edges()) {
sum += e->id();
@ -680,7 +679,6 @@ static void BM_InEdgeIteration(int iters, int num_nodes,
}
}
VLOG(1) << sum;
testing::StopTiming();
}
BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 2);
BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 2);
@ -703,8 +701,9 @@ BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 16);
BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 16);
BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 16);
static void BM_GraphCreation(int iters, int num_nodes, int num_edges_per_node) {
testing::StopTiming();
void BM_GraphCreation(::testing::benchmark::State& state) {
const int num_nodes = state.range(0);
const int num_edges_per_node = state.range(1);
const GraphDef graph_def =
test::CreateGraphDef(num_nodes, num_edges_per_node);
const auto registry = OpRegistry::Global();
@ -713,14 +712,12 @@ static void BM_GraphCreation(int iters, int num_nodes, int num_edges_per_node) {
Graph graph(registry);
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
int64 sum = 0;
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
Graph graph(registry);
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
sum += graph.num_node_ids();
}
VLOG(1) << sum;
testing::StopTiming();
}
BENCHMARK(BM_GraphCreation)->ArgPair(10, 2);
BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 2);
@ -743,8 +740,9 @@ BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 16);
BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 16);
BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 16);
static void BM_ToGraphDef(int iters, int num_nodes, int num_edges_per_node) {
testing::StopTiming();
void BM_ToGraphDef(::testing::benchmark::State& state) {
const int num_nodes = state.range(0);
const int num_edges_per_node = state.range(1);
const GraphDef graph_def =
test::CreateGraphDef(num_nodes, num_edges_per_node);
const auto registry = OpRegistry::Global();
@ -753,14 +751,12 @@ static void BM_ToGraphDef(int iters, int num_nodes, int num_edges_per_node) {
Graph graph(registry);
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
int64 sum = 0;
testing::StartTiming();
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
GraphDef graph_def;
graph.ToGraphDef(&graph_def);
sum += graph_def.node_size();
}
VLOG(1) << sum;
testing::StopTiming();
}
BENCHMARK(BM_ToGraphDef)->ArgPair(10, 2);
BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 2);
@ -783,20 +779,20 @@ BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 16);
BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16);
BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16);
static void BM_RemoveNode(int iters, int num_nodes, int num_edges_per_node) {
testing::StopTiming();
void BM_RemoveNode(::testing::benchmark::State& state) {
const int num_nodes = state.range(0);
const int num_edges_per_node = state.range(1);
const GraphDef graph_def =
test::CreateGraphDef(num_nodes, num_edges_per_node);
const auto registry = OpRegistry::Global();
GraphConstructorOptions opts;
for (int i = 0; i < iters; ++i) {
for (auto s : state) {
Graph graph(registry);
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
testing::StartTiming();
for (Node* n : graph.op_nodes()) {
graph.RemoveNode(n);
}
testing::StopTiming();
}
}
BENCHMARK(BM_RemoveNode)->ArgPair(10, 2);

View File

@ -347,8 +347,8 @@ TEST_F(OptimizerCSETest, Constant_Dedup) {
EXPECT_EQ(node_set.count("n/_3(Const)") + node_set.count("n/_4(Const)"), 1);
}
static void BM_CSE(int iters, int op_nodes) {
testing::StopTiming();
void BM_CSE(::testing::benchmark::State& state) {
const int op_nodes = state.range(0);
string s;
for (int in = 0; in < 10; in++) {
s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
@ -363,7 +363,8 @@ static void BM_CSE(int iters, int op_nodes) {
}
bool first = true;
while (iters > 0) {
for (auto i : state) {
state.PauseTiming();
Graph* graph = new Graph(OpRegistry::Global());
InitGraph(s, graph);
int N = graph->num_node_ids();
@ -372,13 +373,12 @@ static void BM_CSE(int iters, int op_nodes) {
first = false;
}
{
testing::StartTiming();
state.ResumeTiming();
OptimizeCSE(graph, nullptr);
testing::StopTiming();
state.PauseTiming();
}
iters -= N; // Our benchmark units are individual graph nodes,
// not whole graphs
delete graph;
state.ResumeTiming();
}
}
BENCHMARK(BM_CSE)->Arg(1000)->Arg(10000);

View File

@ -342,14 +342,14 @@ TEST_F(SubgraphTest, Errors) {
REGISTER_OP("In").Output("o: float");
REGISTER_OP("Op").Input("i: float").Output("o: float");
static void BM_SubgraphHelper(int iters, int num_nodes,
bool use_function_convention) {
void BM_SubgraphHelper(::testing::benchmark::State& state,
bool use_function_convention) {
const int num_nodes = state.range(0);
DeviceAttributes device_info;
device_info.set_name("/job:a/replica:0/task:0/cpu:0");
device_info.set_device_type(DeviceType(DEVICE_CPU).type());
device_info.set_incarnation(0);
testing::StopTiming();
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@ -371,8 +371,8 @@ static void BM_SubgraphHelper(int iters, int num_nodes,
}
std::vector<string> fetch;
std::vector<string> targets = {strings::StrCat("N", num_nodes - 1)};
testing::StartTiming();
while (--iters > 0) {
for (auto s : state) {
Graph* subgraph = new Graph(OpRegistry::Global());
CopyGraph(g, subgraph);
subgraph::RewriteGraphMetadata metadata;
@ -383,11 +383,11 @@ static void BM_SubgraphHelper(int iters, int num_nodes,
}
}
static void BM_Subgraph(int iters, int num_nodes) {
BM_SubgraphHelper(iters, num_nodes, false /* use_function_convention */);
void BM_Subgraph(::testing::benchmark::State& state) {
BM_SubgraphHelper(state, false /* use_function_convention */);
}
static void BM_SubgraphFunctionConvention(int iters, int num_nodes) {
BM_SubgraphHelper(iters, num_nodes, true /* use_function_convention */);
void BM_SubgraphFunctionConvention(::testing::benchmark::State& state) {
BM_SubgraphHelper(state, true /* use_function_convention */);
}
BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000);
BENCHMARK(BM_SubgraphFunctionConvention)

View File

@ -39,8 +39,8 @@ uint32 Skewed(random::SimplePhilox* rnd, int max_log) {
return rnd->Rand32() % space;
}
void BM_ParseTensorName(int iters, int arg) {
testing::StopTiming();
void BM_ParseTensorName(::testing::benchmark::State& state) {
const int arg = state.range(0);
random::PhiloxRandom philox(301, 17);
random::SimplePhilox rnd(&philox);
std::vector<string> names;
@ -78,11 +78,11 @@ void BM_ParseTensorName(int iters, int arg) {
}
names.push_back(name);
}
testing::StartTiming();
TensorId id;
int index = 0;
int sum = 0;
while (--iters > 0) {
for (auto s : state) {
id = ParseTensorName(names[index++ % names.size()]);
sum += id.second;
}