From c8897e9bce9ae7a79f2fa5e4195aa6824f68ee95 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 10 Aug 2017 13:26:56 -0700 Subject: [PATCH] Static required time computation PiperOrigin-RevId: 164894645 --- .../grappler/optimizers/static_schedule.cc | 66 +++++++++++++++++++ .../grappler/optimizers/static_schedule.h | 11 +++- .../optimizers/static_schedule_test.cc | 46 +++++++++++-- 3 files changed, 118 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/static_schedule.cc b/tensorflow/core/grappler/optimizers/static_schedule.cc index e31499eac66..143cc2d703d 100644 --- a/tensorflow/core/grappler/optimizers/static_schedule.cc +++ b/tensorflow/core/grappler/optimizers/static_schedule.cc @@ -119,5 +119,71 @@ Status EstimateEarliestExecutionTimes( return Status::OK(); } +Status EstimateRequiredTimes( + const GrapplerItem& item, const Cluster* cluster, + const std::unordered_map& + execution_times, + std::unordered_map* required_times) { + std::unordered_map name_map; + for (const NodeDef& node : item.graph.node()) { + name_map[node.name()] = &node; + (*required_times)[&node] = Costs::NanoSeconds::max(); + } + + std::unordered_map pending_fanouts; + for (const NodeDef& node : item.graph.node()) { + for (const string& input : node.input()) { + string node_name = NodeName(input); + auto it = name_map.find(node_name); + if (it == name_map.end()) { + return errors::InvalidArgument( + strings::StrCat("Unknown input node ", input)); + } + const NodeDef* fanin = it->second; + pending_fanouts[fanin] += 1; + } + } + std::deque ready_nodes; + for (const NodeDef& node : item.graph.node()) { + if (pending_fanouts[&node] == 0) { + auto it = execution_times.find(&node); + if (it != execution_times.end()) { + (*required_times)[&node] = it->second; + } + ready_nodes.push_back(&node); + } + } + GraphProperties properties(item); + TF_RETURN_IF_ERROR(properties.InferStatically()); + OpLevelCostEstimator estimator; + VirtualPlacer placer(cluster); + + while (!ready_nodes.empty()) { + const NodeDef* node = ready_nodes.front(); + ready_nodes.pop_front(); + + Costs::NanoSeconds execution_time = + PredictExecutionTime(properties, estimator, placer, *node); + Costs::NanoSeconds required_time = (*required_times)[node] - execution_time; + + for (const string& fanin_name : node->input()) { + const NodeDef* fanin = name_map[NodeName(fanin_name)]; + (*required_times)[fanin] = + std::min((*required_times)[fanin], required_time); + + int pending = pending_fanouts[fanin]; + if (pending == 0) { + // Already processed. Avoid going through loops more than once. + continue; + } else if (pending == 1) { + ready_nodes.push_back(fanin); + } + pending_fanouts[fanin]--; + } + } + + return Status::OK(); +} + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/static_schedule.h b/tensorflow/core/grappler/optimizers/static_schedule.h index 0dd82b0dab1..aa2726a2bdf 100644 --- a/tensorflow/core/grappler/optimizers/static_schedule.h +++ b/tensorflow/core/grappler/optimizers/static_schedule.h @@ -26,7 +26,7 @@ limitations under the License. namespace tensorflow { namespace grappler { -// Compute the earliest time as which the execution of each node in the graph +// Compute the earliest time at which the execution of each node in the graph // can complete. // In our estimation, we ensure that each node takes at least one nanosecond to // execute: therefore the execution times can be used to derive a topological @@ -35,6 +35,15 @@ Status EstimateEarliestExecutionTimes( const GrapplerItem& item, const Cluster* cluster, std::unordered_map* execution_times); +// Compute the time by which the execution of each node must complete to ensure +// the subsequent nodes can still be executed by the times predicted by the +// EstimateEarliestExecutionTimes function. +Status EstimateRequiredTimes( + const GrapplerItem& item, const Cluster* cluster, + const std::unordered_map& + execution_times, + std::unordered_map* required_times); + } // namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/static_schedule_test.cc b/tensorflow/core/grappler/optimizers/static_schedule_test.cc index 95a745be21a..5de59335872 100644 --- a/tensorflow/core/grappler/optimizers/static_schedule_test.cc +++ b/tensorflow/core/grappler/optimizers/static_schedule_test.cc @@ -65,13 +65,13 @@ TEST_F(StaticScheduleTest, BasicGraph) { EXPECT_EQ(Costs::NanoSeconds(1), time.second); } else if (time.first->name() == "x") { EXPECT_EQ(Costs::NanoSeconds(250002), time.second); - } else if (time.first->name() == "AddN") { + } else if (time.first->name() == "Square") { EXPECT_EQ(Costs::NanoSeconds(1500005), time.second); - } else if (time.first->name() == "AddN_1") { + } else if (time.first->name() == "Square_1") { EXPECT_EQ(Costs::NanoSeconds(2750008), time.second); - } else if (time.first->name() == "AddN_2") { + } else if (time.first->name() == "Square_2") { EXPECT_EQ(Costs::NanoSeconds(4000011), time.second); - } else if (time.first->name() == "AddN_3") { + } else if (time.first->name() == "Square_3") { EXPECT_EQ(Costs::NanoSeconds(5250014), time.second); } else if (time.first->name() == "y") { EXPECT_EQ(Costs::NanoSeconds(6500017), time.second); @@ -121,6 +121,44 @@ TEST_F(StaticScheduleTest, BasicGraphWithCtrlDependencies) { } } +TEST_F(StaticScheduleTest, RequiredTimes) { + // This trivial graph is so basic there's nothing to prune. + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + std::unique_ptr cluster(CreateVirtualCluster()); + + std::unordered_map execution_times; + for (const NodeDef& node : item.graph.node()) { + execution_times[&node] = 0; + } + std::unordered_map required_times; + Status status = EstimateRequiredTimes(item, cluster.get(), execution_times, + &required_times); + TF_EXPECT_OK(status); + + EXPECT_EQ(item.graph.node_size(), required_times.size()); + + for (auto time : required_times) { + if (time.first->name() == "Const/Const") { + EXPECT_EQ(Costs::NanoSeconds(-6500016), time.second); + } else if (time.first->name() == "x") { + EXPECT_EQ(Costs::NanoSeconds(-6250015), time.second); + } else if (time.first->name() == "Square") { + EXPECT_EQ(Costs::NanoSeconds(-5000012), time.second); + } else if (time.first->name() == "Square_1") { + EXPECT_EQ(Costs::NanoSeconds(-3750009), time.second); + } else if (time.first->name() == "Square_2") { + EXPECT_EQ(Costs::NanoSeconds(-2500006), time.second); + } else if (time.first->name() == "Square_3") { + EXPECT_EQ(Costs::NanoSeconds(-1250003), time.second); + } else if (time.first->name() == "y") { + EXPECT_EQ(Costs::NanoSeconds(0), time.second); + } + } +} + } // namespace } // namespace grappler } // namespace tensorflow