Internal tests cleanup

PiperOrigin-RevId: 339888534
Change-Id: I60fb13fad42ec5938d7276151afdce2dc3a6926f
This commit is contained in:
A. Unique TensorFlower 2020-10-30 09:49:55 -07:00 committed by TensorFlower Gardener
parent 5cd75eedd5
commit 556c1a4a58
5 changed files with 42 additions and 46 deletions

View File

@ -203,21 +203,21 @@ TEST(AlgorithmTest, PostOrderWithEdgeFilter) {
} }
} }
static void BM_PruneForReverseReachability(int iters, int num_nodes, void BM_PruneForReverseReachability(::testing::benchmark::State& state) {
int num_edges_per_node) { const int num_nodes = state.range(0);
testing::StopTiming(); const int num_edges_per_node = state.range(1);
const GraphDef graph_def = const GraphDef graph_def =
test::CreateGraphDef(num_nodes, num_edges_per_node); test::CreateGraphDef(num_nodes, num_edges_per_node);
const auto registry = OpRegistry::Global(); const auto registry = OpRegistry::Global();
GraphConstructorOptions opts; GraphConstructorOptions opts;
for (int i = 0; i < iters; ++i) { for (auto s : state) {
state.PauseTiming();
Graph graph(registry); Graph graph(registry);
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
std::unordered_set<const Node*> visited; std::unordered_set<const Node*> visited;
visited.insert(graph.FindNodeId(graph.num_nodes() - 1)); visited.insert(graph.FindNodeId(graph.num_nodes() - 1));
testing::StartTiming(); state.ResumeTiming();
PruneForReverseReachability(&graph, std::move(visited)); PruneForReverseReachability(&graph, std::move(visited));
testing::StopTiming();
} }
} }
BENCHMARK(BM_PruneForReverseReachability)->ArgPair(10, 2); BENCHMARK(BM_PruneForReverseReachability)->ArgPair(10, 2);

View File

@ -661,9 +661,9 @@ TEST_F(GraphTest, BuildNodeNameIndex) {
} }
} }
static void BM_InEdgeIteration(int iters, int num_nodes, void BM_InEdgeIteration(::testing::benchmark::State& state) {
int num_edges_per_node) { const int num_nodes = state.range(0);
testing::StopTiming(); const int num_edges_per_node = state.range(1);
const GraphDef graph_def = const GraphDef graph_def =
test::CreateGraphDef(num_nodes, num_edges_per_node); test::CreateGraphDef(num_nodes, num_edges_per_node);
Graph graph(OpRegistry::Global()); Graph graph(OpRegistry::Global());
@ -671,8 +671,7 @@ static void BM_InEdgeIteration(int iters, int num_nodes,
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
int64 sum = 0; int64 sum = 0;
testing::StartTiming(); for (auto s : state) {
for (int i = 0; i < iters; ++i) {
for (const Node* node : graph.nodes()) { for (const Node* node : graph.nodes()) {
for (auto e : node->in_edges()) { for (auto e : node->in_edges()) {
sum += e->id(); sum += e->id();
@ -680,7 +679,6 @@ static void BM_InEdgeIteration(int iters, int num_nodes,
} }
} }
VLOG(1) << sum; VLOG(1) << sum;
testing::StopTiming();
} }
BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 2); BENCHMARK(BM_InEdgeIteration)->ArgPair(10, 2);
BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 2); BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 6, 2);
@ -703,8 +701,9 @@ BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 9, 16);
BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 16); BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 12, 16);
BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 16); BENCHMARK(BM_InEdgeIteration)->ArgPair(1 << 15, 16);
static void BM_GraphCreation(int iters, int num_nodes, int num_edges_per_node) { void BM_GraphCreation(::testing::benchmark::State& state) {
testing::StopTiming(); const int num_nodes = state.range(0);
const int num_edges_per_node = state.range(1);
const GraphDef graph_def = const GraphDef graph_def =
test::CreateGraphDef(num_nodes, num_edges_per_node); test::CreateGraphDef(num_nodes, num_edges_per_node);
const auto registry = OpRegistry::Global(); const auto registry = OpRegistry::Global();
@ -713,14 +712,12 @@ static void BM_GraphCreation(int iters, int num_nodes, int num_edges_per_node) {
Graph graph(registry); Graph graph(registry);
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
int64 sum = 0; int64 sum = 0;
testing::StartTiming(); for (auto s : state) {
for (int i = 0; i < iters; ++i) {
Graph graph(registry); Graph graph(registry);
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
sum += graph.num_node_ids(); sum += graph.num_node_ids();
} }
VLOG(1) << sum; VLOG(1) << sum;
testing::StopTiming();
} }
BENCHMARK(BM_GraphCreation)->ArgPair(10, 2); BENCHMARK(BM_GraphCreation)->ArgPair(10, 2);
BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 2); BENCHMARK(BM_GraphCreation)->ArgPair(1 << 6, 2);
@ -743,8 +740,9 @@ BENCHMARK(BM_GraphCreation)->ArgPair(1 << 9, 16);
BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 16); BENCHMARK(BM_GraphCreation)->ArgPair(1 << 12, 16);
BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 16); BENCHMARK(BM_GraphCreation)->ArgPair(1 << 15, 16);
static void BM_ToGraphDef(int iters, int num_nodes, int num_edges_per_node) { void BM_ToGraphDef(::testing::benchmark::State& state) {
testing::StopTiming(); const int num_nodes = state.range(0);
const int num_edges_per_node = state.range(1);
const GraphDef graph_def = const GraphDef graph_def =
test::CreateGraphDef(num_nodes, num_edges_per_node); test::CreateGraphDef(num_nodes, num_edges_per_node);
const auto registry = OpRegistry::Global(); const auto registry = OpRegistry::Global();
@ -753,14 +751,12 @@ static void BM_ToGraphDef(int iters, int num_nodes, int num_edges_per_node) {
Graph graph(registry); Graph graph(registry);
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
int64 sum = 0; int64 sum = 0;
testing::StartTiming(); for (auto s : state) {
for (int i = 0; i < iters; ++i) {
GraphDef graph_def; GraphDef graph_def;
graph.ToGraphDef(&graph_def); graph.ToGraphDef(&graph_def);
sum += graph_def.node_size(); sum += graph_def.node_size();
} }
VLOG(1) << sum; VLOG(1) << sum;
testing::StopTiming();
} }
BENCHMARK(BM_ToGraphDef)->ArgPair(10, 2); BENCHMARK(BM_ToGraphDef)->ArgPair(10, 2);
BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 2); BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 6, 2);
@ -783,20 +779,20 @@ BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 9, 16);
BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16); BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 12, 16);
BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16); BENCHMARK(BM_ToGraphDef)->ArgPair(1 << 15, 16);
static void BM_RemoveNode(int iters, int num_nodes, int num_edges_per_node) { void BM_RemoveNode(::testing::benchmark::State& state) {
testing::StopTiming(); const int num_nodes = state.range(0);
const int num_edges_per_node = state.range(1);
const GraphDef graph_def = const GraphDef graph_def =
test::CreateGraphDef(num_nodes, num_edges_per_node); test::CreateGraphDef(num_nodes, num_edges_per_node);
const auto registry = OpRegistry::Global(); const auto registry = OpRegistry::Global();
GraphConstructorOptions opts; GraphConstructorOptions opts;
for (int i = 0; i < iters; ++i) { for (auto s : state) {
Graph graph(registry); Graph graph(registry);
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph)); TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, &graph));
testing::StartTiming(); testing::StartTiming();
for (Node* n : graph.op_nodes()) { for (Node* n : graph.op_nodes()) {
graph.RemoveNode(n); graph.RemoveNode(n);
} }
testing::StopTiming();
} }
} }
BENCHMARK(BM_RemoveNode)->ArgPair(10, 2); BENCHMARK(BM_RemoveNode)->ArgPair(10, 2);

View File

@ -347,8 +347,8 @@ TEST_F(OptimizerCSETest, Constant_Dedup) {
EXPECT_EQ(node_set.count("n/_3(Const)") + node_set.count("n/_4(Const)"), 1); EXPECT_EQ(node_set.count("n/_3(Const)") + node_set.count("n/_4(Const)"), 1);
} }
static void BM_CSE(int iters, int op_nodes) { void BM_CSE(::testing::benchmark::State& state) {
testing::StopTiming(); const int op_nodes = state.range(0);
string s; string s;
for (int in = 0; in < 10; in++) { for (int in = 0; in < 10; in++) {
s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in); s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
@ -363,7 +363,8 @@ static void BM_CSE(int iters, int op_nodes) {
} }
bool first = true; bool first = true;
while (iters > 0) { for (auto i : state) {
state.PauseTiming();
Graph* graph = new Graph(OpRegistry::Global()); Graph* graph = new Graph(OpRegistry::Global());
InitGraph(s, graph); InitGraph(s, graph);
int N = graph->num_node_ids(); int N = graph->num_node_ids();
@ -372,13 +373,12 @@ static void BM_CSE(int iters, int op_nodes) {
first = false; first = false;
} }
{ {
testing::StartTiming(); state.ResumeTiming();
OptimizeCSE(graph, nullptr); OptimizeCSE(graph, nullptr);
testing::StopTiming(); state.PauseTiming();
} }
iters -= N; // Our benchmark units are individual graph nodes,
// not whole graphs
delete graph; delete graph;
state.ResumeTiming();
} }
} }
BENCHMARK(BM_CSE)->Arg(1000)->Arg(10000); BENCHMARK(BM_CSE)->Arg(1000)->Arg(10000);

View File

@ -342,14 +342,14 @@ TEST_F(SubgraphTest, Errors) {
REGISTER_OP("In").Output("o: float"); REGISTER_OP("In").Output("o: float");
REGISTER_OP("Op").Input("i: float").Output("o: float"); REGISTER_OP("Op").Input("i: float").Output("o: float");
static void BM_SubgraphHelper(int iters, int num_nodes, void BM_SubgraphHelper(::testing::benchmark::State& state,
bool use_function_convention) { bool use_function_convention) {
const int num_nodes = state.range(0);
DeviceAttributes device_info; DeviceAttributes device_info;
device_info.set_name("/job:a/replica:0/task:0/cpu:0"); device_info.set_name("/job:a/replica:0/task:0/cpu:0");
device_info.set_device_type(DeviceType(DEVICE_CPU).type()); device_info.set_device_type(DeviceType(DEVICE_CPU).type());
device_info.set_incarnation(0); device_info.set_incarnation(0);
testing::StopTiming();
Graph g(OpRegistry::Global()); Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g. { // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately); GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@ -371,8 +371,8 @@ static void BM_SubgraphHelper(int iters, int num_nodes,
} }
std::vector<string> fetch; std::vector<string> fetch;
std::vector<string> targets = {strings::StrCat("N", num_nodes - 1)}; std::vector<string> targets = {strings::StrCat("N", num_nodes - 1)};
testing::StartTiming();
while (--iters > 0) { for (auto s : state) {
Graph* subgraph = new Graph(OpRegistry::Global()); Graph* subgraph = new Graph(OpRegistry::Global());
CopyGraph(g, subgraph); CopyGraph(g, subgraph);
subgraph::RewriteGraphMetadata metadata; subgraph::RewriteGraphMetadata metadata;
@ -383,11 +383,11 @@ static void BM_SubgraphHelper(int iters, int num_nodes,
} }
} }
static void BM_Subgraph(int iters, int num_nodes) { void BM_Subgraph(::testing::benchmark::State& state) {
BM_SubgraphHelper(iters, num_nodes, false /* use_function_convention */); BM_SubgraphHelper(state, false /* use_function_convention */);
} }
static void BM_SubgraphFunctionConvention(int iters, int num_nodes) { void BM_SubgraphFunctionConvention(::testing::benchmark::State& state) {
BM_SubgraphHelper(iters, num_nodes, true /* use_function_convention */); BM_SubgraphHelper(state, true /* use_function_convention */);
} }
BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000); BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000);
BENCHMARK(BM_SubgraphFunctionConvention) BENCHMARK(BM_SubgraphFunctionConvention)

View File

@ -39,8 +39,8 @@ uint32 Skewed(random::SimplePhilox* rnd, int max_log) {
return rnd->Rand32() % space; return rnd->Rand32() % space;
} }
void BM_ParseTensorName(int iters, int arg) { void BM_ParseTensorName(::testing::benchmark::State& state) {
testing::StopTiming(); const int arg = state.range(0);
random::PhiloxRandom philox(301, 17); random::PhiloxRandom philox(301, 17);
random::SimplePhilox rnd(&philox); random::SimplePhilox rnd(&philox);
std::vector<string> names; std::vector<string> names;
@ -78,11 +78,11 @@ void BM_ParseTensorName(int iters, int arg) {
} }
names.push_back(name); names.push_back(name);
} }
testing::StartTiming();
TensorId id; TensorId id;
int index = 0; int index = 0;
int sum = 0; int sum = 0;
while (--iters > 0) { for (auto s : state) {
id = ParseTensorName(names[index++ % names.size()]); id = ParseTensorName(names[index++ % names.size()]);
sum += id.second; sum += id.second;
} }