diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 7b4ed10e7e5..e557adc2111 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -213,6 +213,7 @@ cc_library(
         "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/costs:graph_properties",
+        "//tensorflow/core/grappler/utils:topological_sort",
     ],
 )
 
@@ -231,6 +232,7 @@ tf_cc_test(
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler:utils",
         "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+        "//tensorflow/core/grappler/utils:topological_sort",
     ],
 )
 
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
index 1e97d2d8d28..498a3a443f7 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.cc
@@ -23,8 +23,10 @@ limitations under the License.
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/util/device_name_utils.h"
 
@@ -77,6 +79,7 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) {
     return false;
   }
   if (!fetch_nodes_known_ || NumNonControlOutputs(node, *node_map_) > 0) {
+    // The output values of this node may be needed.
     return false;
   }
   if (IsMerge(node) || IsSwitch(node)) {
@@ -203,7 +206,7 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
     if (num_inputs * num_outputs > num_inputs + num_outputs) {
       return;
     }
-    VLOG(1) << "***** Rerouting input around " << node->name();
+    VLOG(1) << "***** Rerouting input around " << node->name();
     std::vector<NodeDef*> input_nodes;
     for (int i = 0; i < num_inputs; ++i) {
       NodeDef* tmp = node_map_->GetNode(node->input(i));
@@ -291,6 +294,112 @@ Status DependencyOptimizer::OptimizeDependencies() {
   return Status::OK();
 }
 
+Status DependencyOptimizer::TransitiveReduction() {
+  // PRECONDITION: optimized_graph_ must be sorted topologically.
+  const int num_nodes = optimized_graph_->node_size();
+  // Set up a compressed version of the graph to save a constant factor in the
+  // expensive algorithm below. Also cache the set of control outputs and the
+  // highest index of a target of any control output from each node.
+  int num_controls = 0;
+  std::vector<gtl::InlinedVector<int, 4>> inputs(num_nodes);
+  std::vector<gtl::InlinedVector<std::pair<int, int>, 2>> control_outputs(
+      num_nodes);
+  for (int node_idx = 0; node_idx < num_nodes; ++node_idx) {
+    const NodeDef& node = optimized_graph_->node(node_idx);
+    if (ModifiesFrameInfo(node)) {
+      // Ignore nodes that modify frame info.
+      continue;
+    }
+    for (int input_slot = 0; input_slot < node.input_size(); ++input_slot) {
+      const string& input = node.input(input_slot);
+      const NodeDef* input_node = node_map_->GetNode(input);
+      if (input_node == nullptr || ModifiesFrameInfo(*input_node)) {
+        // Ignore inputs the node map cannot resolve (assuming GetNode returns
+        // nullptr for them) and edges from nodes that modify frame info.
+        continue;
+      }
+      const int input_node_idx = node_to_idx_[input_node];
+      inputs[node_idx].push_back(input_node_idx);
+      if (IsControlInput(input)) {
+        ++num_controls;
+        control_outputs[input_node_idx].emplace_back(node_idx, input_slot);
+      }
+    }
+  }
+
+  // Run the longest path in DAG algorithm for each source node that has control
+  // outputs. If, for any target node of a control output, there exists a path
+  // of length > 1, we can drop that control dependency. Redundant edges are
+  // only recorded here and removed after the scan: removing an input on the
+  // fly moves another input of the same target into its slot, which would
+  // invalidate the input slots cached in control_outputs.
+  struct ControlEdge {
+    int source;
+    int target;
+    int input_slot;
+  };
+  std::vector<ControlEdge> redundant_edges;
+  std::vector<int> longest_distance(num_nodes);
+  for (int source = 0; source < num_nodes; ++source) {
+    int highest_control_target = -1;
+    for (const auto& control_output : control_outputs[source]) {
+      if (control_output.first > highest_control_target) {
+        highest_control_target = control_output.first;
+      }
+    }
+    if (highest_control_target < source) {
+      continue;
+    }
+    std::fill(longest_distance.begin() + source,
+              longest_distance.begin() + highest_control_target + 1, 0);
+    for (int target = source + 1; target <= highest_control_target; ++target) {
+      for (int input : inputs[target]) {
+        // If the input node is before source in the topo order, no path
+        // source -> input -> target can exist and we can skip it.
+        if (input >= source) {
+          // If source -> input -> target is longer than the longest
+          // path so far from source -> target, update the longest_distance.
+          int candidate_longest_distance = longest_distance[input] + 1;
+          if (candidate_longest_distance > longest_distance[target]) {
+            longest_distance[target] = candidate_longest_distance;
+          }
+        }
+      }
+    }
+
+    // If the longest path from the source to the target of a control dependency
+    // is longer than 1, there exists an alternate path, and we can eliminate
+    // the control dependency since it is redundant.
+    for (const auto& control_output : control_outputs[source]) {
+      const int target = control_output.first;
+      if (longest_distance[target] > 1) {
+        redundant_edges.push_back({source, target, control_output.second});
+      }
+    }
+  }
+
+  // Remove the redundant edges target by target, highest input slot first.
+  // Each removal swaps the input with the target's last input and drops the
+  // tail, which only disturbs slots at or above the one being removed, so the
+  // lower slots still waiting for removal stay valid.
+  std::sort(redundant_edges.begin(), redundant_edges.end(),
+            [](const ControlEdge& a, const ControlEdge& b) {
+              if (a.target != b.target) return a.target < b.target;
+              return a.input_slot > b.input_slot;
+            });
+  for (const ControlEdge& edge : redundant_edges) {
+    const NodeDef& source_node = optimized_graph_->node(edge.source);
+    NodeDef* target_node = optimized_graph_->mutable_node(edge.target);
+    target_node->mutable_input()->SwapElements(edge.input_slot,
+                                               target_node->input_size() - 1);
+    node_map_->RemoveOutput(source_node.name(), target_node->name());
+    target_node->mutable_input()->RemoveLast();
+  }
+  VLOG(1) << "Removed " << redundant_edges.size() << " out of " << num_controls
+          << " control dependencies";
+  return Status::OK();
+}
+
 void DependencyOptimizer::BuildNodeToIdx() {
   // Set up &node -> index map.
   node_to_idx_.clear();
@@ -302,16 +411,35 @@ void DependencyOptimizer::BuildNodeToIdx() {
 
 Status DependencyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                      GraphDef* optimized_graph) {
   optimized_graph_ = optimized_graph;
   *optimized_graph_ = item.graph;
   nodes_to_preserve_ = item.NodesToPreserve();
   fetch_nodes_known_ = !item.fetch.empty();
-  node_map_.reset(new NodeMap(optimized_graph_));
-  BuildNodeToIdx();
-
+  // Log only after optimized_graph_ has been pointed at this call's output.
   VLOG(1) << "Graph before optimization:\n" << optimized_graph_->DebugString();
-  TF_RETURN_IF_ERROR(OptimizeDependencies());
   CleanControlInputs();
+  const int num_iterations = opt_level_ == RewriterConfig::AGGRESSIVE ? 2 : 1;
+  for (int iteration = 0; iteration < num_iterations; ++iteration) {
+    Status topo_sort_status = TopologicalSort(optimized_graph_);
+    node_map_.reset(new NodeMap(optimized_graph_));
+    BuildNodeToIdx();
+
+    // Remove redundant control dependencies.
+    if (opt_level_ == RewriterConfig::AGGRESSIVE) {
+      if (topo_sort_status.ok()) {
+        TF_RETURN_IF_ERROR(TransitiveReduction());
+      } else {
+        LOG(ERROR) << topo_sort_status.error_message();
+      }
+      VLOG(1) << "Graph after transitive reduction:\n"
+              << optimized_graph_->DebugString();
+    }
+
+    // Turn nodes without non-control outputs into NoOps, prune NoOps.
+    TF_RETURN_IF_ERROR(OptimizeDependencies());
+    VLOG(1) << "Graph after NoOp conversion & pruning:\n"
+            << optimized_graph_->DebugString();
+  }
 
   VLOG(1) << "Graph after optimization:\n" << optimized_graph_->DebugString();
   return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer.h b/tensorflow/core/grappler/optimizers/dependency_optimizer.h
index f9d4d0b6c2d..3f6f418bee6 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer.h
@@ -56,6 +56,9 @@ class DependencyOptimizer : public GraphOptimizer {
   // inserting them in nodes_to_delete.
   void OptimizeNode(int node_idx, SetVector<int>* nodes_to_simplify,
                     std::set<int>* nodes_to_delete);
+  // Eliminates redundant control dependencies by computing the transitive
+  // reduction of the graph.
+  Status TransitiveReduction();
   // Main driver of dependency optimizations.
   Status OptimizeDependencies();
 
diff --git a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
index e714f5c0421..d91525f8148 100644
--- a/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/constant_folding.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -122,21 +123,22 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop) { EXPECT_EQ(item.graph.node_size(), output.node_size()); for (int i = 0; i < item.graph.node_size(); ++i) { - const NodeDef& original = item.graph.node(i); - const NodeDef& optimized = output.node(i); - EXPECT_EQ(original.name(), optimized.name()); - if (original.name() == "add") { - EXPECT_EQ("NoOp", optimized.op()); - } else { - EXPECT_EQ(original.op(), optimized.op()); - } - EXPECT_EQ(original.input_size(), optimized.input_size()); - for (int j = 0; j < original.input_size(); ++j) { - if (original.name() == "add") { - EXPECT_EQ(AsControlDependency(original.input(j)), optimized.input(j)); - } else { - EXPECT_EQ(original.input(j), optimized.input(j)); - } + const NodeDef& node = item.graph.node(i); + if (node.name() == "add") { + EXPECT_EQ("NoOp", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("^x", node.input(0)); + EXPECT_EQ("^y", node.input(1)); + } else if (node.name() == "id1") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^y", node.input(1)); + } else if (node.name() == "id2") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("^x", node.input(1)); } } } @@ -160,6 +162,7 @@ TEST_F(DependencyOptimizerTest, ChangeToNoop_NoFetch) { Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); + TF_CHECK_OK(TopologicalSort(&item.graph)); VerifyGraphsEqual(item.graph, output, __FUNCTION__); } @@ -234,6 +237,27 @@ TEST_F(DependencyOptimizerTest, RemoveNoOps_SingleInputOrOutput) { } } 
+TEST_F(DependencyOptimizerTest, Transitive_Reduction_Simple) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
+  Output x = ops::Square(s.WithOpName("x"), c);
+  Output id1 = ops::Identity(s.WithOpName("id1"), x);
+  // id2 already depends on x via id1, so its control edge on x is redundant.
+  Output id2 =
+      ops::Identity(s.WithOpName("id2").WithControlDependencies({x}), id1);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch.push_back("id2");
+  DependencyOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  GraphDef output;
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+  ASSERT_EQ(4, output.node_size());  // Guards the node(3) accesses below.
+  EXPECT_EQ("id2", output.node(3).name());
+  EXPECT_EQ(1, output.node(3).input_size());
+  EXPECT_EQ("id1", output.node(3).input(0));
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow