Fixed static_schedule_test to not be a change detector for the node-based cost estimator.

PiperOrigin-RevId: 305061845
Change-Id: I61f6dcf96023cbf977b6df02e5de70c7de846acd
This commit is contained in:
A. Unique TensorFlower 2020-04-06 10:30:16 -07:00 committed by TensorFlower Gardener
parent 5068c9c94b
commit f87850a654

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/static_schedule.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
@ -45,6 +46,42 @@ class StaticScheduleTest : public ::testing::Test {
}
};
// Returns the completion times of the nodes in completion order.
std::vector<Costs::NanoSeconds> GetOrderedTimes(
const std::unordered_map<const NodeDef*, Costs::NanoSeconds>
completion_times) {
std::map<Costs::NanoSeconds, std::string> ordered_completion_times;
for (const auto& node_def_time : completion_times) {
ordered_completion_times[node_def_time.second] =
node_def_time.first->name();
}
std::vector<Costs::NanoSeconds> ordered_times;
for (const auto& time_node_name : ordered_completion_times) {
ordered_times.push_back(time_node_name.first);
}
return ordered_times;
}
// Returns the names of the completed nodes in completion order.
std::vector<std::string> GetOrderedNodeNames(
const std::unordered_map<const NodeDef*, Costs::NanoSeconds>
completion_times) {
std::map<Costs::NanoSeconds, std::string> ordered_completion_times;
for (const auto& node_def_time : completion_times) {
ordered_completion_times[node_def_time.second] =
node_def_time.first->name();
}
std::vector<std::string> ordered_node_names;
for (const auto& time_node_name : ordered_completion_times) {
ordered_node_names.push_back(time_node_name.second);
}
return ordered_node_names;
}
TEST_F(StaticScheduleTest, BasicGraph) {
// This trivial graph is so basic there's nothing to prune.
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
@ -60,23 +97,22 @@ TEST_F(StaticScheduleTest, BasicGraph) {
EXPECT_EQ(item.graph.node_size(), completion_times.size());
for (auto time : completion_times) {
if (time.first->name() == "Const/Const") {
EXPECT_EQ(Costs::NanoSeconds(1), time.second);
} else if (time.first->name() == "x") {
EXPECT_EQ(Costs::NanoSeconds(1500001), time.second);
} else if (time.first->name() == "Square") {
EXPECT_EQ(Costs::NanoSeconds(4000004), time.second);
} else if (time.first->name() == "Square_1") {
EXPECT_EQ(Costs::NanoSeconds(6500007), time.second);
} else if (time.first->name() == "Square_2") {
EXPECT_EQ(Costs::NanoSeconds(9000010), time.second);
} else if (time.first->name() == "Square_3") {
EXPECT_EQ(Costs::NanoSeconds(11500013), time.second);
} else if (time.first->name() == "y") {
EXPECT_EQ(Costs::NanoSeconds(14000013), time.second);
// Check that the completion times are strictly ascending, starting at 1.
std::vector<Costs::NanoSeconds> ordered_times =
GetOrderedTimes(completion_times);
for (int i = 0; i < ordered_times.size(); ++i) {
if (i > 0) {
EXPECT_GT(ordered_times[i], ordered_times[i - 1]);
}
}
EXPECT_EQ(ordered_times[0], Costs::NanoSeconds(1));
// Check that the completions schedule is as expected.
std::vector<std::string> ordered_node_names =
GetOrderedNodeNames(completion_times);
EXPECT_EQ(ordered_node_names,
(std::vector<std::string>{"Const/Const", "x", "Square", "Square_1",
"Square_2", "Square_3", "y"}));
}
TEST_F(StaticScheduleTest, BasicGraphWithCtrlDependencies) {
@ -106,19 +142,21 @@ TEST_F(StaticScheduleTest, BasicGraphWithCtrlDependencies) {
EXPECT_EQ(item.graph.node_size(), completion_times.size());
for (auto time : completion_times) {
if (time.first->name() == "a") {
EXPECT_EQ(Costs::NanoSeconds(1), time.second);
} else if (time.first->name() == "b") {
EXPECT_EQ(Costs::NanoSeconds(25000001), time.second);
} else if (time.first->name() == "c") {
EXPECT_EQ(Costs::NanoSeconds(25000002), time.second);
} else if (time.first->name() == "d") {
EXPECT_EQ(Costs::NanoSeconds(25000003), time.second);
} else if (time.first->name() == "e") {
EXPECT_EQ(Costs::NanoSeconds(50000003), time.second);
// Check that the completion times are strictly ascending, starting at 1.
std::vector<Costs::NanoSeconds> ordered_times =
GetOrderedTimes(completion_times);
for (int i = 0; i < ordered_times.size(); ++i) {
if (i > 0) {
EXPECT_GT(ordered_times[i], ordered_times[i - 1]);
}
}
EXPECT_EQ(ordered_times[0], Costs::NanoSeconds(1));
// Check that the completions schedule is as expected.
std::vector<std::string> ordered_node_names =
GetOrderedNodeNames(completion_times);
EXPECT_EQ(ordered_node_names,
(std::vector<std::string>{"a", "b", "c", "d", "e"}));
}
TEST_F(StaticScheduleTest, RequiredTimes) {
@ -140,23 +178,22 @@ TEST_F(StaticScheduleTest, RequiredTimes) {
EXPECT_EQ(item.graph.node_size(), required_times.size());
for (auto time : required_times) {
if (time.first->name() == "Const/Const") {
EXPECT_EQ(Costs::NanoSeconds(-14000012), time.second);
} else if (time.first->name() == "x") {
EXPECT_EQ(Costs::NanoSeconds(-12500012), time.second);
} else if (time.first->name() == "Square") {
EXPECT_EQ(Costs::NanoSeconds(-10000009), time.second);
} else if (time.first->name() == "Square_1") {
EXPECT_EQ(Costs::NanoSeconds(-7500006), time.second);
} else if (time.first->name() == "Square_2") {
EXPECT_EQ(Costs::NanoSeconds(-5000003), time.second);
} else if (time.first->name() == "Square_3") {
EXPECT_EQ(Costs::NanoSeconds(-2500000), time.second);
} else if (time.first->name() == "y") {
EXPECT_EQ(Costs::NanoSeconds(0), time.second);
// Check that the expecution times are strictly ascending, ending at 0.
std::vector<Costs::NanoSeconds> ordered_times =
GetOrderedTimes(required_times);
for (int i = 0; i < ordered_times.size(); ++i) {
if (i > 0) {
EXPECT_GT(ordered_times[i], ordered_times[i - 1]);
}
}
EXPECT_EQ(ordered_times[ordered_times.size() - 1], Costs::NanoSeconds(0));
// Check that the completions schedule is as expected.
std::vector<std::string> ordered_node_names =
GetOrderedNodeNames(required_times);
EXPECT_EQ(ordered_node_names,
(std::vector<std::string>{"Const/Const", "x", "Square", "Square_1",
"Square_2", "Square_3", "y"}));
}
} // namespace