Static required time computation

PiperOrigin-RevId: 164894645
This commit is contained in:
Benoit Steiner 2017-08-10 13:26:56 -07:00 committed by TensorFlower Gardener
parent 076158f9b9
commit c8897e9bce
3 changed files with 118 additions and 5 deletions

View File

@ -119,5 +119,71 @@ Status EstimateEarliestExecutionTimes(
return Status::OK();
}
// Walks the graph backwards from the nodes that have no consumers and records,
// for every node, the latest completion time that still lets all of its
// consumers meet their own required times.
//
// Nodes without consumers are seeded with their entry in `execution_times`
// (when present); every other node starts at Costs::NanoSeconds::max() and is
// lowered to the minimum over its consumers of (consumer required time minus
// consumer predicted execution time).
//
// Returns InvalidArgument if a node input refers to a node that is not part of
// the graph, and forwards any error reported by static shape inference.
Status EstimateRequiredTimes(
    const GrapplerItem& item, const Cluster* cluster,
    const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
        execution_times,
    std::unordered_map<const NodeDef*, Costs::NanoSeconds>* required_times) {
  // Index the nodes by name and mark every deadline as unconstrained.
  std::unordered_map<string, const NodeDef*> nodes_by_name;
  for (const NodeDef& current : item.graph.node()) {
    (*required_times)[&current] = Costs::NanoSeconds::max();
    nodes_by_name[current.name()] = &current;
  }

  // Count, for each node, how many input edges of other nodes refer to it.
  std::unordered_map<const NodeDef*, int> unprocessed_consumers;
  for (const NodeDef& consumer : item.graph.node()) {
    for (const string& input : consumer.input()) {
      const auto producer = nodes_by_name.find(NodeName(input));
      if (producer == nodes_by_name.end()) {
        return errors::InvalidArgument(
            strings::StrCat("Unknown input node ", input));
      }
      ++unprocessed_consumers[producer->second];
    }
  }

  // Nodes nobody consumes are the starting points of the backward traversal.
  // Their deadline is the execution time supplied by the caller, if any.
  std::deque<const NodeDef*> worklist;
  for (const NodeDef& current : item.graph.node()) {
    const auto consumers = unprocessed_consumers.find(&current);
    if (consumers != unprocessed_consumers.end() && consumers->second != 0) {
      continue;
    }
    const auto known_time = execution_times.find(&current);
    if (known_time != execution_times.end()) {
      (*required_times)[&current] = known_time->second;
    }
    worklist.push_back(&current);
  }

  GraphProperties properties(item);
  TF_RETURN_IF_ERROR(properties.InferStatically());
  OpLevelCostEstimator estimator;
  VirtualPlacer placer(cluster);

  while (!worklist.empty()) {
    const NodeDef* current = worklist.front();
    worklist.pop_front();

    // Inputs of `current` must be done early enough for `current` to run and
    // still finish by its own deadline.
    const Costs::NanoSeconds duration =
        PredictExecutionTime(properties, estimator, placer, *current);
    const Costs::NanoSeconds input_deadline =
        (*required_times)[current] - duration;

    for (const string& input : current->input()) {
      const NodeDef* producer = nodes_by_name[NodeName(input)];

      Costs::NanoSeconds& producer_deadline = (*required_times)[producer];
      producer_deadline = std::min(producer_deadline, input_deadline);

      int& remaining = unprocessed_consumers[producer];
      if (remaining == 0) {
        // This producer was already handled; do not walk loops repeatedly.
        continue;
      }
      if (remaining == 1) {
        // Last outstanding consumer edge: the producer's deadline is final.
        worklist.push_back(producer);
      }
      --remaining;
    }
  }

  return Status::OK();
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -26,7 +26,7 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
// Compute the earliest time as which the execution of each node in the graph
// Compute the earliest time at which the execution of each node in the graph
// can complete.
// In our estimation, we ensure that each node takes at least one nanosecond to
// execute: therefore the execution times can be used to derive a topological
@ -35,6 +35,15 @@ Status EstimateEarliestExecutionTimes(
const GrapplerItem& item, const Cluster* cluster,
std::unordered_map<const NodeDef*, Costs::NanoSeconds>* execution_times);
// Compute the time by which the execution of each node must complete to ensure
// the subsequent nodes can still be executed by the times predicted by the
// EstimateEarliestExecutionTimes function.
//
// Arguments:
//   item: the grappler item whose graph is analyzed.
//   cluster: used to place the nodes when predicting their execution time.
//   execution_times: completion time of each node. Only the entries of nodes
//     that are not the input of any other node are read; they are used as the
//     required time of those nodes.
//   required_times: output. Receives one entry per node of the graph. The
//     required time of a node with consumers is the minimum, over its
//     consumers, of the consumer's required time minus the consumer's
//     predicted execution time. An entry is left at Costs::NanoSeconds::max()
//     when the node has no consumers and no entry in execution_times.
//
// Returns an InvalidArgument error if a node input refers to a node that is
// not part of the graph, and propagates static shape inference failures.
Status EstimateRequiredTimes(
const GrapplerItem& item, const Cluster* cluster,
const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
execution_times,
std::unordered_map<const NodeDef*, Costs::NanoSeconds>* required_times);
} // namespace grappler
} // end namespace tensorflow

View File

@ -65,13 +65,13 @@ TEST_F(StaticScheduleTest, BasicGraph) {
EXPECT_EQ(Costs::NanoSeconds(1), time.second);
} else if (time.first->name() == "x") {
EXPECT_EQ(Costs::NanoSeconds(250002), time.second);
} else if (time.first->name() == "AddN") {
} else if (time.first->name() == "Square") {
EXPECT_EQ(Costs::NanoSeconds(1500005), time.second);
} else if (time.first->name() == "AddN_1") {
} else if (time.first->name() == "Square_1") {
EXPECT_EQ(Costs::NanoSeconds(2750008), time.second);
} else if (time.first->name() == "AddN_2") {
} else if (time.first->name() == "Square_2") {
EXPECT_EQ(Costs::NanoSeconds(4000011), time.second);
} else if (time.first->name() == "AddN_3") {
} else if (time.first->name() == "Square_3") {
EXPECT_EQ(Costs::NanoSeconds(5250014), time.second);
} else if (time.first->name() == "y") {
EXPECT_EQ(Costs::NanoSeconds(6500017), time.second);
@ -121,6 +121,44 @@ TEST_F(StaticScheduleTest, BasicGraphWithCtrlDependencies) {
}
}
TEST_F(StaticScheduleTest, RequiredTimes) {
  // Build the trivial test graph used throughout this test suite.
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());

  // Give every node an execution time of zero: the required times computed
  // below are then offsets relative to the completion of the final node, which
  // is why all the expected values are zero or negative.
  std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
  for (const NodeDef& node : item.graph.node()) {
    execution_times[&node] = 0;
  }

  std::unordered_map<const NodeDef*, Costs::NanoSeconds> required_times;
  TF_EXPECT_OK(EstimateRequiredTimes(item, cluster.get(), execution_times,
                                     &required_times));
  EXPECT_EQ(item.graph.node_size(), required_times.size());

  // Expected required time of each known node, keyed by node name.
  const std::unordered_map<string, Costs::NanoSeconds> expected_by_name = {
      {"Const/Const", Costs::NanoSeconds(-6500016)},
      {"x", Costs::NanoSeconds(-6250015)},
      {"Square", Costs::NanoSeconds(-5000012)},
      {"Square_1", Costs::NanoSeconds(-3750009)},
      {"Square_2", Costs::NanoSeconds(-2500006)},
      {"Square_3", Costs::NanoSeconds(-1250003)},
      {"y", Costs::NanoSeconds(0)},
  };

  for (const auto& entry : required_times) {
    const auto expected = expected_by_name.find(entry.first->name());
    if (expected == expected_by_name.end()) {
      // Nodes without a listed expectation are not checked.
      continue;
    }
    EXPECT_EQ(expected->second, entry.second);
  }
}
} // namespace
} // namespace grappler
} // namespace tensorflow