Static required time computation
PiperOrigin-RevId: 164894645
This commit is contained in:
parent
076158f9b9
commit
c8897e9bce
tensorflow/core/grappler/optimizers
@ -119,5 +119,71 @@ Status EstimateEarliestExecutionTimes(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Computes, for every node of item.graph, the time by which it has to finish
// so that its consumers can still meet their own required times.
//
// Nodes without any fanout are seeded with the value found for them in
// `execution_times` (when present); every other node starts at
// Costs::NanoSeconds::max() and is tightened while walking the graph from the
// fanout-free nodes back towards their fanins.
//
// Returns InvalidArgument if a node lists an input that is not part of the
// graph, or the error reported by the static shape inference.
Status EstimateRequiredTimes(
    const GrapplerItem& item, const Cluster* cluster,
    const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
        execution_times,
    std::unordered_map<const NodeDef*, Costs::NanoSeconds>* required_times) {
  // Index the nodes by name and start every node with no deadline at all.
  std::unordered_map<string, const NodeDef*> nodes_by_name;
  for (const NodeDef& node : item.graph.node()) {
    nodes_by_name[node.name()] = &node;
    (*required_times)[&node] = Costs::NanoSeconds::max();
  }

  // Count, for each node, how many input edges of other nodes refer to it.
  std::unordered_map<const NodeDef*, int> remaining_fanouts;
  for (const NodeDef& node : item.graph.node()) {
    for (const string& input : node.input()) {
      const auto producer = nodes_by_name.find(NodeName(input));
      if (producer == nodes_by_name.end()) {
        return errors::InvalidArgument(
            strings::StrCat("Unknown input node ", input));
      }
      ++remaining_fanouts[producer->second];
    }
  }

  // Nodes nobody consumes are the starting points of the backward traversal.
  // Their deadline is their own predicted execution time, when one was given.
  std::deque<const NodeDef*> worklist;
  for (const NodeDef& node : item.graph.node()) {
    // Only nodes referenced as an input ever got an entry (with count >= 1),
    // so a missing entry is the same as a count of zero.
    if (remaining_fanouts.find(&node) != remaining_fanouts.end()) {
      continue;
    }
    const auto predicted = execution_times.find(&node);
    if (predicted != execution_times.end()) {
      (*required_times)[&node] = predicted->second;
    }
    worklist.push_back(&node);
  }

  GraphProperties properties(item);
  TF_RETURN_IF_ERROR(properties.InferStatically());
  OpLevelCostEstimator estimator;
  VirtualPlacer placer(cluster);

  while (!worklist.empty()) {
    const NodeDef* current = worklist.front();
    worklist.pop_front();

    // The fanins of `current` must be done early enough to leave room for the
    // execution of `current` itself.
    Costs::NanoSeconds duration =
        PredictExecutionTime(properties, estimator, placer, *current);
    Costs::NanoSeconds fanin_deadline = (*required_times)[current] - duration;

    for (const string& input : current->input()) {
      // Every input name was validated above, so the lookup always succeeds.
      const NodeDef* fanin = nodes_by_name.find(NodeName(input))->second;

      Costs::NanoSeconds& deadline = (*required_times)[fanin];
      deadline = std::min(deadline, fanin_deadline);

      int& remaining = remaining_fanouts[fanin];
      if (remaining == 0) {
        // Already processed. Avoid going through loops more than once.
        continue;
      }
      if (remaining == 1) {
        // This was the last outstanding consumer: the fanin is now final.
        worklist.push_back(fanin);
      }
      --remaining;
    }
  }

  return Status::OK();
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Compute the earliest time as which the execution of each node in the graph
|
||||
// Compute the earliest time at which the execution of each node in the graph
|
||||
// can complete.
|
||||
// In our estimation, we ensure that each node takes at least one nanosecond to
|
||||
// execute: therefore the execution times can be used to derive a topological
|
||||
@ -35,6 +35,15 @@ Status EstimateEarliestExecutionTimes(
|
||||
const GrapplerItem& item, const Cluster* cluster,
|
||||
std::unordered_map<const NodeDef*, Costs::NanoSeconds>* execution_times);
|
||||
|
||||
// Compute the time by which the execution of each node must complete to ensure
|
||||
// the subsequent nodes can still be executed by the times predicted by the
|
||||
// EstimateEarliestExecutionTimes function.
// 'execution_times' supplies the starting value for the nodes that have no
// fanout; nodes missing from it keep Costs::NanoSeconds::max() as their
// starting value. On success 'required_times' holds one entry per node of
// item.graph.
// Returns an InvalidArgument error if a node input refers to a node that is not
// in the graph.
|
||||
Status EstimateRequiredTimes(
|
||||
const GrapplerItem& item, const Cluster* cluster,
|
||||
const std::unordered_map<const NodeDef*, Costs::NanoSeconds>&
|
||||
execution_times,
|
||||
std::unordered_map<const NodeDef*, Costs::NanoSeconds>* required_times);
|
||||
|
||||
} // namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
@ -65,13 +65,13 @@ TEST_F(StaticScheduleTest, BasicGraph) {
|
||||
EXPECT_EQ(Costs::NanoSeconds(1), time.second);
|
||||
} else if (time.first->name() == "x") {
|
||||
EXPECT_EQ(Costs::NanoSeconds(250002), time.second);
|
||||
} else if (time.first->name() == "AddN") {
|
||||
} else if (time.first->name() == "Square") {
|
||||
EXPECT_EQ(Costs::NanoSeconds(1500005), time.second);
|
||||
} else if (time.first->name() == "AddN_1") {
|
||||
} else if (time.first->name() == "Square_1") {
|
||||
EXPECT_EQ(Costs::NanoSeconds(2750008), time.second);
|
||||
} else if (time.first->name() == "AddN_2") {
|
||||
} else if (time.first->name() == "Square_2") {
|
||||
EXPECT_EQ(Costs::NanoSeconds(4000011), time.second);
|
||||
} else if (time.first->name() == "AddN_3") {
|
||||
} else if (time.first->name() == "Square_3") {
|
||||
EXPECT_EQ(Costs::NanoSeconds(5250014), time.second);
|
||||
} else if (time.first->name() == "y") {
|
||||
EXPECT_EQ(Costs::NanoSeconds(6500017), time.second);
|
||||
@ -121,6 +121,44 @@ TEST_F(StaticScheduleTest, BasicGraphWithCtrlDependencies) {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks the required times computed for the graph produced by the trivial
// test input yielder when every node is given an execution time of zero.
TEST_F(StaticScheduleTest, RequiredTimes) {
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  std::unique_ptr<VirtualCluster> cluster(CreateVirtualCluster());

  // Use zero as the execution time of every node, so the reported required
  // times are all expressed relative to zero.
  std::unordered_map<const NodeDef*, Costs::NanoSeconds> execution_times;
  for (const NodeDef& node : item.graph.node()) {
    execution_times[&node] = 0;
  }

  std::unordered_map<const NodeDef*, Costs::NanoSeconds> required_times;
  TF_EXPECT_OK(EstimateRequiredTimes(item, cluster.get(), execution_times,
                                     &required_times));

  // Every node of the graph must have received a required time.
  EXPECT_EQ(item.graph.node_size(), required_times.size());

  // Expected required time, in nanoseconds, for each node we care about.
  // Nodes that are not listed here are not checked.
  const std::unordered_map<std::string, int> expected_ns = {
      {"Const/Const", -6500016}, {"x", -6250015},
      {"Square", -5000012},      {"Square_1", -3750009},
      {"Square_2", -2500006},    {"Square_3", -1250003},
      {"y", 0},
  };

  for (const auto& entry : required_times) {
    const auto expected = expected_ns.find(entry.first->name());
    if (expected == expected_ns.end()) {
      continue;
    }
    EXPECT_EQ(Costs::NanoSeconds(expected->second), entry.second);
  }
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user