Eliminate redundant control dependencies by computing the transitive reduction of the graph G = (V, E). The graph is turned into a DAG by breaking loops. We sort the DAG topologically and apply, at each source of control dependencies, the linear time algorithm for computing longest paths in a DAG. We can eliminate redundant control dependencies when there exists a path of length > 1 from source to target.

Worst case time complexity is O(\sum_{v \in V_c} |{(v, u) : topo(v) < topo(u) <= max(topo(z)) , where (v,z) \in E_c}|),

V_c \subset V is the set of nodes with control outputs, E_c \subset E is the set of control edges and topo(u) is the index of node u in the topological ordering of V.

-------------------------------------------------------------------------------------

Results on learning/brain/experimental/grappler/data/inceptionv3.meta:

Runtime for pass: ~30 ms.
Removes 12% of control dependencies, removes 3.7% of nodes.

I1201 15:54:44.624856   38255 dependency_optimizer.cc:351] Finished deduping control inputs
I1201 15:54:44.673534   38255 dependency_optimizer.cc:354] Finished topo sort
I1201 15:54:44.719586   38255 dependency_optimizer.cc:286] Finished compression
I1201 15:54:44.729909   38255 dependency_optimizer.cc:334] Finished reduction
I1201 15:54:44.729917   38255 dependency_optimizer.cc:337] Removed 519 out of 4325 control dependencies
I1201 15:54:44.890641   38255 dependency_optimizer.cc:245] Deleted 499 out of 13535 nodes for deletion.

PiperOrigin-RevId: 178289073
This commit is contained in:
A. Unique TensorFlower 2017-12-07 13:48:07 -08:00 committed by TensorFlower Gardener
parent 488f09179f
commit 89804a9c68
4 changed files with 159 additions and 21 deletions

View File

@ -213,6 +213,7 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:topological_sort",
],
)
@ -231,6 +232,7 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
"//tensorflow/core/grappler/utils:topological_sort",
],
)

View File

@ -23,8 +23,10 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"
@ -77,6 +79,7 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
return false;
}
if (!fetch_nodes_known_ || NumNonControlOutputs(node, *node_map_) > 0) {
// The output values of this node may be needed.
return false;
}
if (IsMerge(node) || IsSwitch(node)) {
@ -291,6 +294,94 @@ Status DependencyOptimizer::OptimizeDependencies() {
return Status::OK();
}
Status DependencyOptimizer::TransitiveReduction() {
// PRECONDITION: optimized_graph_ must be sorted topologically.
const int num_nodes = optimized_graph_->node_size();
// Set up a compressed version of the graph to save a constant factor in the
// expensive algorithm below. Also cache the set of control outputs and the
// highest index of a target of any control output from each node.
int num_controls = 0;
std::vector<gtl::InlinedVector<int, 4>> inputs(num_nodes);
std::vector<gtl::InlinedVector<std::pair<int, int>, 2>> control_outputs(
num_nodes);
for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
const NodeDef& node = optimized_graph_->node(node_idx);
if (ModifiesFrameInfo(node)) {
// Ignore nodes that modify frame info.
continue;
}
for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) {
const string& input = node.input(input_slot);
const NodeDef* input_node = node_map_->GetNode(input);
if (ModifiesFrameInfo(*input_node)) {
// Ignore edges from nodes that modify frame info.
continue;
}
const int input_node_idx = node_to_idx_[input_node];
inputs[node_idx].push_back(input_node_idx);
if (IsControlInput(input)) {
++num_controls;
control_outputs[input_node_idx].emplace_back(node_idx, input_slot);
}
}
}
// Run the longest path in DAG algorithm for each source node that has control
// outputs. If, for any target node of a control output, there exists a path
// of length > 1, we can drop that control dependency.
int num_controls_removed = 0;
std::vector<int> longest_distance(num_nodes);
for (int source = 0; source < num_nodes; ++source) {
int highest_control_target = -1;
for (const auto& control_output : control_outputs[source]) {
if (control_output.first > highest_control_target) {
highest_control_target = control_output.first;
}
}
if (highest_control_target < source) {
continue;
}
std::fill(longest_distance.begin() + source,
longest_distance.begin() + highest_control_target + 1, 0);
for (int target = source + 1; target <= highest_control_target; ++target) {
for (int input : inputs[target]) {
// If the input node is before source in the topo order, no path
// source -> input -> target can exits and we can skip it.
if (input >= source) {
// If source -> input -> target is longer than the longest
// path so far from source -> target, update the longest_distance.
int candidate_longest_distance = longest_distance[input] + 1;
if (candidate_longest_distance > longest_distance[target]) {
longest_distance[target] = candidate_longest_distance;
}
}
}
}
// If the longest path from the source to the target of a control dependency
// is longer than 1, there exists an alternate path, and we can eliminate
// the control dependency since it is redundant.
for (const auto& control_output : control_outputs[source]) {
const int target = control_output.first;
if (longest_distance[target] > 1) {
const int input_slot = control_output.second;
// We modify the node inplace here. This is safe because there can
// only be one control edge from a given source to a given target.
const NodeDef& source_node = optimized_graph_->node(source);
NodeDef* target_node = optimized_graph_->mutable_node(target);
target_node->mutable_input()->SwapElements(
input_slot, target_node->input_size() - 1);
node_map_->RemoveOutput(source_node.name(), target_node->name());
target_node->mutable_input()->RemoveLast();
++num_controls_removed;
}
}
}
VLOG(1) << "Removed " << num_controls_removed << " out of " << num_controls
<< " control dependencies";
return Status::OK();
}
void DependencyOptimizer::BuildNodeToIdx() {
// Set up &node -> index map.
node_to_idx_.clear();
@ -302,17 +393,35 @@ void DependencyOptimizer::BuildNodeToIdx() {
Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
VLOG(1) << "Graph before optimization:\n" << optimized_graph_->DebugString();
optimized_graph_ = optimized_graph;
*optimized_graph_ = item.graph;
nodes_to_preserve_ = item.NodesToPreserve();
fetch_nodes_known_ = !item.fetch.empty();
CleanControlInputs();
const int num_iterations = opt_level_ == RewriterConfig::AGGRESSIVE ? 2 : 1;
for (int iteration = 0; iteration < num_iterations; ++iteration) {
Status topo_sort_status = TopologicalSort(optimized_graph_);
node_map_.reset(new NodeMap(optimized_graph_));
BuildNodeToIdx();
VLOG(1) << "Graph before optimization:\n" << optimized_graph_->DebugString();
TF_RETURN_IF_ERROR(OptimizeDependencies());
// Remove redundant control dependencies, iteration 1.
if (opt_level_ == RewriterConfig::AGGRESSIVE) {
if (topo_sort_status.ok()) {
TF_RETURN_IF_ERROR(TransitiveReduction());
} else {
LOG(ERROR) << topo_sort_status.error_message();
}
VLOG(1) << "Graph after transitive reduction:\n"
<< optimized_graph_->DebugString();
}
CleanControlInputs();
// Turn nodes without non-control outputs into NoOps, prune NoOps.
TF_RETURN_IF_ERROR(OptimizeDependencies());
VLOG(1) << "Graph after NoOp conversion & pruning:\n"
<< optimized_graph_->DebugString();
}
VLOG(1) << "Graph after optimization:\n" << optimized_graph_->DebugString();
return Status::OK();

View File

@ -56,6 +56,9 @@ class DependencyOptimizer : public GraphOptimizer {
// inserting them in nodes_to_delete.
void OptimizeNode(int node_idx, SetVector<int>* nodes_to_simplify,
std::set<int>* nodes_to_delete);
// Eliminates redundant control dependencies by computing the transitive
// reduction of the graph.
Status TransitiveReduction();
// Main driver of dependency optimizations.
Status OptimizeDependencies();

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@ -122,21 +123,22 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop) {
EXPECT_EQ(item.graph.node_size(), output.node_size());
for (int i = 0; i < item.graph.node_size(); ++i) {
const NodeDef& original = item.graph.node(i);
const NodeDef& optimized = output.node(i);
EXPECT_EQ(original.name(), optimized.name());
if (original.name() == "add") {
EXPECT_EQ("NoOp", optimized.op());
} else {
EXPECT_EQ(original.op(), optimized.op());
}
EXPECT_EQ(original.input_size(), optimized.input_size());
for (int j = 0; j < original.input_size(); ++j) {
if (original.name() == "add") {
EXPECT_EQ(AsControlDependency(original.input(j)), optimized.input(j));
} else {
EXPECT_EQ(original.input(j), optimized.input(j));
}
const NodeDef& node = item.graph.node(i);
if (node.name() == "add") {
EXPECT_EQ("NoOp", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("^x", node.input(0));
EXPECT_EQ("^y", node.input(1));
} else if (node.name() == "id1") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^y", node.input(1));
} else if (node.name() == "id2") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ(2, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("^x", node.input(1));
}
}
}
@ -160,6 +162,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_NoFetch) {
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
TF_CHECK_OK(TopologicalSort(&item.graph));
VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
@ -234,6 +237,27 @@ TEST_F(DependencyOptimizerTest, RemoveNoOps_SingleInputOrOutput) {
}
}
TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
Output x = ops::Square(s.WithOpName("x"), c);
Output id1 = ops::Identity(s.WithOpName("id1"), x);
Output id2 =
ops::Identity(s.WithOpName("id2").WithControlDependencies({x}), id1);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
item.fetch.push_back("id2");
DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
EXPECT_EQ(4, output.node_size());
EXPECT_EQ("id2", output.node(3).name());
EXPECT_EQ(1, output.node(3).input_size());
EXPECT_EQ("id1", output.node(3).input(0));
}
} // namespace
} // namespace grappler
} // namespace tensorflow