Eliminate redundant control dependencies by computing the transitive reduction of the graph G = (V, E). The graph is turned into a DAG by breaking loops. We sort the DAG topologically and apply, at each source of control dependencies, the linear time algorithm for computing longest paths in a DAG. We can eliminate redundant control dependencies when there exists a path of length > 1 from source to target.
Worst case time complexity is O(\sum_{v \in V_c} |{(v, u) : topo(v) < topo(u) <= max(topo(z)) , where (v,z) \in E_c}|), V_c \subset V is the set of nodes with control outputs, E_c \subset E is the set of control edges and topo(u) is the index of node u in the topological ordering of V. ------------------------------------------------------------------------------------- Results on learning/brain/experimental/grappler/data/inceptionv3.meta: Runtime for pass: ~30 ms. Removes 12% of control dependencies, removes 3.7% of nodes. I1201 15:54:44.624856 38255 dependency_optimizer.cc:351] Finished deduping control inputs I1201 15:54:44.673534 38255 dependency_optimizer.cc:354] Finished topo sort I1201 15:54:44.719586 38255 dependency_optimizer.cc:286] Finished compression I1201 15:54:44.729909 38255 dependency_optimizer.cc:334] Finished reduction I1201 15:54:44.729917 38255 dependency_optimizer.cc:337] Removed 519 out of 4325 control dependencies I1201 15:54:44.890641 38255 dependency_optimizer.cc:245] Deleted 499 out of 13535 nodes for deletion. PiperOrigin-RevId: 178289073
This commit is contained in:
parent
488f09179f
commit
89804a9c68
@ -213,6 +213,7 @@ cc_library(
|
|||||||
"//tensorflow/core/grappler:op_types",
|
"//tensorflow/core/grappler:op_types",
|
||||||
"//tensorflow/core/grappler:utils",
|
"//tensorflow/core/grappler:utils",
|
||||||
"//tensorflow/core/grappler/costs:graph_properties",
|
"//tensorflow/core/grappler/costs:graph_properties",
|
||||||
|
"//tensorflow/core/grappler/utils:topological_sort",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -231,6 +232,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core/grappler:grappler_item",
|
"//tensorflow/core/grappler:grappler_item",
|
||||||
"//tensorflow/core/grappler:utils",
|
"//tensorflow/core/grappler:utils",
|
||||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||||
|
"//tensorflow/core/grappler/utils:topological_sort",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,8 +23,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/grappler/grappler_item.h"
|
#include "tensorflow/core/grappler/grappler_item.h"
|
||||||
#include "tensorflow/core/grappler/op_types.h"
|
#include "tensorflow/core/grappler/op_types.h"
|
||||||
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
||||||
|
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
#include "tensorflow/core/lib/strings/strcat.h"
|
#include "tensorflow/core/lib/strings/strcat.h"
|
||||||
#include "tensorflow/core/util/device_name_utils.h"
|
#include "tensorflow/core/util/device_name_utils.h"
|
||||||
|
|
||||||
@ -77,6 +79,7 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!fetch_nodes_known_ || NumNonControlOutputs(node, *node_map_) > 0) {
|
if (!fetch_nodes_known_ || NumNonControlOutputs(node, *node_map_) > 0) {
|
||||||
|
// The output values of this node may be needed.
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (IsMerge(node) || IsSwitch(node)) {
|
if (IsMerge(node) || IsSwitch(node)) {
|
||||||
@ -203,7 +206,7 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
|
|||||||
if (num_inputs * num_outputs > num_inputs + num_outputs) {
|
if (num_inputs * num_outputs > num_inputs + num_outputs) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
VLOG(1) << "***** Rerouting input around " << node->name();
|
VLOG(1) << "***** Rerouting input around " << node->name();
|
||||||
std::vector<NodeDef*> input_nodes;
|
std::vector<NodeDef*> input_nodes;
|
||||||
for (int i = 0; i < num_inputs; ++i) {
|
for (int i = 0; i < num_inputs; ++i) {
|
||||||
NodeDef* tmp = node_map_->GetNode(node->input(i));
|
NodeDef* tmp = node_map_->GetNode(node->input(i));
|
||||||
@ -291,6 +294,94 @@ Status DependencyOptimizer::OptimizeDependencies() {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status DependencyOptimizer::TransitiveReduction() {
|
||||||
|
// PRECONDITION: optimized_graph_ must be sorted topologically.
|
||||||
|
const int num_nodes = optimized_graph_->node_size();
|
||||||
|
// Set up a compressed version of the graph to save a constant factor in the
|
||||||
|
// expensive algorithm below. Also cache the set of control outputs and the
|
||||||
|
// highest index of a target of any control output from each node.
|
||||||
|
int num_controls = 0;
|
||||||
|
std::vector<gtl::InlinedVector<int, 4>> inputs(num_nodes);
|
||||||
|
std::vector<gtl::InlinedVector<std::pair<int, int>, 2>> control_outputs(
|
||||||
|
num_nodes);
|
||||||
|
for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
|
||||||
|
const NodeDef& node = optimized_graph_->node(node_idx);
|
||||||
|
if (ModifiesFrameInfo(node)) {
|
||||||
|
// Ignore nodes that modify frame info.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) {
|
||||||
|
const string& input = node.input(input_slot);
|
||||||
|
const NodeDef* input_node = node_map_->GetNode(input);
|
||||||
|
if (ModifiesFrameInfo(*input_node)) {
|
||||||
|
// Ignore edges from nodes that modify frame info.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const int input_node_idx = node_to_idx_[input_node];
|
||||||
|
inputs[node_idx].push_back(input_node_idx);
|
||||||
|
if (IsControlInput(input)) {
|
||||||
|
++num_controls;
|
||||||
|
control_outputs[input_node_idx].emplace_back(node_idx, input_slot);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the longest path in DAG algorithm for each source node that has control
|
||||||
|
// outputs. If, for any target node of a control output, there exists a path
|
||||||
|
// of length > 1, we can drop that control dependency.
|
||||||
|
int num_controls_removed = 0;
|
||||||
|
std::vector<int> longest_distance(num_nodes);
|
||||||
|
for (int source = 0; source < num_nodes; ++source) {
|
||||||
|
int highest_control_target = -1;
|
||||||
|
for (const auto& control_output : control_outputs[source]) {
|
||||||
|
if (control_output.first > highest_control_target) {
|
||||||
|
highest_control_target = control_output.first;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (highest_control_target < source) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
std::fill(longest_distance.begin() + source,
|
||||||
|
longest_distance.begin() + highest_control_target + 1, 0);
|
||||||
|
for (int target = source + 1; target <= highest_control_target; ++target) {
|
||||||
|
for (int input : inputs[target]) {
|
||||||
|
// If the input node is before source in the topo order, no path
|
||||||
|
// source -> input -> target can exits and we can skip it.
|
||||||
|
if (input >= source) {
|
||||||
|
// If source -> input -> target is longer than the longest
|
||||||
|
// path so far from source -> target, update the longest_distance.
|
||||||
|
int candidate_longest_distance = longest_distance[input] + 1;
|
||||||
|
if (candidate_longest_distance > longest_distance[target]) {
|
||||||
|
longest_distance[target] = candidate_longest_distance;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the longest path from the source to the target of a control dependency
|
||||||
|
// is longer than 1, there exists an alternate path, and we can eliminate
|
||||||
|
// the control dependency since it is redundant.
|
||||||
|
for (const auto& control_output : control_outputs[source]) {
|
||||||
|
const int target = control_output.first;
|
||||||
|
if (longest_distance[target] > 1) {
|
||||||
|
const int input_slot = control_output.second;
|
||||||
|
// We modify the node inplace here. This is safe because there can
|
||||||
|
// only be one control edge from a given source to a given target.
|
||||||
|
const NodeDef& source_node = optimized_graph_->node(source);
|
||||||
|
NodeDef* target_node = optimized_graph_->mutable_node(target);
|
||||||
|
target_node->mutable_input()->SwapElements(
|
||||||
|
input_slot, target_node->input_size() - 1);
|
||||||
|
node_map_->RemoveOutput(source_node.name(), target_node->name());
|
||||||
|
target_node->mutable_input()->RemoveLast();
|
||||||
|
++num_controls_removed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls
|
||||||
|
<< " control dependencies";
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
void DependencyOptimizer::BuildNodeToIdx() {
|
void DependencyOptimizer::BuildNodeToIdx() {
|
||||||
// Set up &node -> index map.
|
// Set up &node -> index map.
|
||||||
node_to_idx_.clear();
|
node_to_idx_.clear();
|
||||||
@ -302,17 +393,35 @@ void DependencyOptimizer::BuildNodeToIdx() {
|
|||||||
|
|
||||||
Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||||
GraphDef* optimized_graph) {
|
GraphDef* optimized_graph) {
|
||||||
|
VLOG(1) << "Graph before optimization:\n" << optimized_graph_->DebugString();
|
||||||
optimized_graph_ = optimized_graph;
|
optimized_graph_ = optimized_graph;
|
||||||
*optimized_graph_ = item.graph;
|
*optimized_graph_ = item.graph;
|
||||||
nodes_to_preserve_ = item.NodesToPreserve();
|
nodes_to_preserve_ = item.NodesToPreserve();
|
||||||
fetch_nodes_known_ = !item.fetch.empty();
|
fetch_nodes_known_ = !item.fetch.empty();
|
||||||
node_map_.reset(new NodeMap(optimized_graph_));
|
|
||||||
BuildNodeToIdx();
|
|
||||||
|
|
||||||
VLOG(1) << "Graph before optimization:\n" << optimized_graph_->DebugString();
|
|
||||||
TF_RETURN_IF_ERROR(OptimizeDependencies());
|
|
||||||
|
|
||||||
CleanControlInputs();
|
CleanControlInputs();
|
||||||
|
const int num_iterations = opt_level_ == RewriterConfig::AGGRESSIVE ? 2 : 1;
|
||||||
|
for (int iteration = 0; iteration < num_iterations; ++iteration) {
|
||||||
|
Status topo_sort_status = TopologicalSort(optimized_graph_);
|
||||||
|
node_map_.reset(new NodeMap(optimized_graph_));
|
||||||
|
BuildNodeToIdx();
|
||||||
|
|
||||||
|
// Remove redundant control dependencies, iteration 1.
|
||||||
|
if (opt_level_ == RewriterConfig::AGGRESSIVE) {
|
||||||
|
if (topo_sort_status.ok()) {
|
||||||
|
TF_RETURN_IF_ERROR(TransitiveReduction());
|
||||||
|
} else {
|
||||||
|
LOG(ERROR) << topo_sort_status.error_message();
|
||||||
|
}
|
||||||
|
VLOG(1) << "Graph after transitive reduction:\n"
|
||||||
|
<< optimized_graph_->DebugString();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Turn nodes without non-control outputs into NoOps, prune NoOps.
|
||||||
|
TF_RETURN_IF_ERROR(OptimizeDependencies());
|
||||||
|
VLOG(1) << "Graph after NoOp conversion & pruning:\n"
|
||||||
|
<< optimized_graph_->DebugString();
|
||||||
|
}
|
||||||
VLOG(1) << "Graph after optimization:\n" << optimized_graph_->DebugString();
|
VLOG(1) << "Graph after optimization:\n" << optimized_graph_->DebugString();
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -56,6 +56,9 @@ class DependencyOptimizer : public GraphOptimizer {
|
|||||||
// inserting them in nodes_to_delete.
|
// inserting them in nodes_to_delete.
|
||||||
void OptimizeNode(int node_idx, SetVector<int>* nodes_to_simplify,
|
void OptimizeNode(int node_idx, SetVector<int>* nodes_to_simplify,
|
||||||
std::set<int>* nodes_to_delete);
|
std::set<int>* nodes_to_delete);
|
||||||
|
// Eliminates redundant control dependencies by computing the transitive
|
||||||
|
// reduction of the graph.
|
||||||
|
Status TransitiveReduction();
|
||||||
// Main driver of dependency optimizations.
|
// Main driver of dependency optimizations.
|
||||||
Status OptimizeDependencies();
|
Status OptimizeDependencies();
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
||||||
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
||||||
#include "tensorflow/core/grappler/utils.h"
|
#include "tensorflow/core/grappler/utils.h"
|
||||||
|
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
@ -122,21 +123,22 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop) {
|
|||||||
|
|
||||||
EXPECT_EQ(item.graph.node_size(), output.node_size());
|
EXPECT_EQ(item.graph.node_size(), output.node_size());
|
||||||
for (int i = 0; i < item.graph.node_size(); ++i) {
|
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||||
const NodeDef& original = item.graph.node(i);
|
const NodeDef& node = item.graph.node(i);
|
||||||
const NodeDef& optimized = output.node(i);
|
if (node.name() == "add") {
|
||||||
EXPECT_EQ(original.name(), optimized.name());
|
EXPECT_EQ("NoOp", node.op());
|
||||||
if (original.name() == "add") {
|
EXPECT_EQ(2, node.input_size());
|
||||||
EXPECT_EQ("NoOp", optimized.op());
|
EXPECT_EQ("^x", node.input(0));
|
||||||
} else {
|
EXPECT_EQ("^y", node.input(1));
|
||||||
EXPECT_EQ(original.op(), optimized.op());
|
} else if (node.name() == "id1") {
|
||||||
}
|
EXPECT_EQ("Identity", node.op());
|
||||||
EXPECT_EQ(original.input_size(), optimized.input_size());
|
EXPECT_EQ(2, node.input_size());
|
||||||
for (int j = 0; j < original.input_size(); ++j) {
|
EXPECT_EQ("x", node.input(0));
|
||||||
if (original.name() == "add") {
|
EXPECT_EQ("^y", node.input(1));
|
||||||
EXPECT_EQ(AsControlDependency(original.input(j)), optimized.input(j));
|
} else if (node.name() == "id2") {
|
||||||
} else {
|
EXPECT_EQ("Identity", node.op());
|
||||||
EXPECT_EQ(original.input(j), optimized.input(j));
|
EXPECT_EQ(2, node.input_size());
|
||||||
}
|
EXPECT_EQ("y", node.input(0));
|
||||||
|
EXPECT_EQ("^x", node.input(1));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -160,6 +162,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_NoFetch) {
|
|||||||
Status status = optimizer.Optimize(nullptr, item, &output);
|
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||||
TF_EXPECT_OK(status);
|
TF_EXPECT_OK(status);
|
||||||
|
|
||||||
|
TF_CHECK_OK(TopologicalSort(&item.graph));
|
||||||
VerifyGraphsEqual(item.graph, output, __FUNCTION__);
|
VerifyGraphsEqual(item.graph, output, __FUNCTION__);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -234,6 +237,27 @@ TEST_F(DependencyOptimizerTest, RemoveNoOps_SingleInputOrOutput) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
|
||||||
|
Output x = ops::Square(s.WithOpName("x"), c);
|
||||||
|
Output id1 = ops::Identity(s.WithOpName("id1"), x);
|
||||||
|
Output id2 =
|
||||||
|
ops::Identity(s.WithOpName("id2").WithControlDependencies({x}), id1);
|
||||||
|
|
||||||
|
GrapplerItem item;
|
||||||
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
|
item.fetch.push_back("id2");
|
||||||
|
DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
|
||||||
|
GraphDef output;
|
||||||
|
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||||
|
TF_EXPECT_OK(status);
|
||||||
|
EXPECT_EQ(4, output.node_size());
|
||||||
|
EXPECT_EQ("id2", output.node(3).name());
|
||||||
|
EXPECT_EQ(1, output.node(3).input_size());
|
||||||
|
EXPECT_EQ("id1", output.node(3).input(0));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace grappler
|
} // namespace grappler
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user