From 762b84bef3abe29e9c6a5b3dae3a22f1d44d1c4b Mon Sep 17 00:00:00 2001 From: Ayush Dubey Date: Tue, 19 Nov 2019 11:10:43 -0800 Subject: [PATCH] Add NodeMap update for incoming control edges in ScopedAllocatorOptimizer. Ops that are being fused by the scoped allocator optimizer may have control edges that need to be retained once the ops are fused. The expectation is that in the rewired graph, the incoming control edges will be attached to the _ScopedAllocatorConcat node, and the outgoing control edges will be attached to the _ScopedAllocatorSplit node. This change adds a missing update to NodeMap while rewiring control edges. It also adds tests for confirming correct control edge rewiring. PiperOrigin-RevId: 281334731 Change-Id: I9b7f8ca4ccfc7d66657d0ee6a69ad5562083ccb7 --- tensorflow/core/grappler/optimizers/BUILD | 1 + .../optimizers/scoped_allocator_optimizer.cc | 12 +- .../scoped_allocator_optimizer_test.cc | 125 +++++++++++++++++- 3 files changed, 127 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 80e6bcf31f3..7a97b3df80a 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -923,6 +923,7 @@ tf_cc_test( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + "//tensorflow/core/grappler/utils:topological_sort", ], ) diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc index 12018a56f6a..818bbc5e57a 100644 --- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc @@ -575,7 +575,8 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { const TensorShape& sa_shape, std::vector* sac_inputs) { VLOG(2) << "BuildSAConcatNode " << 
sac_name; - std::set sac_ctl_inputs; + // control input: edge name -> source node name + absl::flat_hash_map sac_ctl_inputs; for (int i = 0; i < ops.size(); ++i) { NodeDef* old_op = ops[i]; for (const string& old_op_input : old_op->input()) { @@ -584,7 +585,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { if (position == -1) { // A control input: drop if from another member of the op set. if (op_instance_names.find(old_op_input) == op_instance_names.end()) { - sac_ctl_inputs.insert(old_op_input); + sac_ctl_inputs.emplace(old_op_input, input_name); } } else { // TODO(tucker): remove redundant check. @@ -620,8 +621,11 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter { node_map->AddOutput(sa_name, sac_name); // Attach the old control inputs to the new sac node. - for (const string& ctl_input : sac_ctl_inputs) { - sac_node->add_input(ctl_input); + for (const auto& ctl_input : sac_ctl_inputs) { + const auto& ctl_edge = ctl_input.first; + const auto& input_name = ctl_input.second; + sac_node->add_input(ctl_edge); + node_map->AddOutput(input_name, sac_node->name()); } return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc index 7cab416a4ab..12eefe79808 100644 --- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/test.h" @@ -147,6 +148,46 @@ class ScopedAllocatorOptimizerTest : public ::testing::Test { TF_CHECK_OK(s.ToGraphDef(graph_def)); } + // Constructs the following graph. + // + // a and b are data inputs. ctl1 and ctl2 are control inputs. a1 and a2 are + // Abs ops. o1 and o2 are data outputs. a1 -> ctl3 and a2 -> ctl4 are + // control edges. + // + // After the optimizer runs, we expect the ctl1 and ctl2 to be connected to + // the SAConcat node, and ctl3 and ctl4 to be connected to SASplit node. + /* + a ctl1 b ctl2 + \ / \ / + a1 a2 + / \ / \ + o1 ctl3 o2 ctl4 + */ + void BuildAbsGraphWithInputAndOutputControlEdges(GraphDef* graph_def) { + Scope s = Scope::NewRootScope(); + s = s.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0"); + + Output a = + ops::Const(s.WithOpName("a"), {0.0, 0.0, 0.0, 0.0}, {2, 2}); + Output b = + ops::Const(s.WithOpName("b"), {0.0, 0.0, 0.0, 0.0}, {2, 2}); + Output ctl1 = + ops::Const(s.WithOpName("ctl1"), {0.0, 0.0, 0.0, 0.0}, {2, 2}); + Output ctl2 = + ops::Const(s.WithOpName("ctl2"), {0.0, 0.0, 0.0, 0.0}, {2, 2}); + Output a1 = ops::Abs(s.WithOpName("a1").WithControlDependencies({ctl1}), a); + Output a2 = ops::Abs(s.WithOpName("a2").WithControlDependencies({ctl2}), b); + Output o1 = ops::Reshape(s.WithOpName("o1"), a1, {1, 4}); + Output o2 = ops::Reshape(s.WithOpName("o2"), a2, {4, 1}); + Output ctl3 = + ops::Const(s.WithOpName("ctl3").WithControlDependencies({a1}), + {0.0, 0.0, 0.0, 0.0}, {2, 2}); + Output ctl4 = + ops::Const(s.WithOpName("ctl4").WithControlDependencies({a2}), + {0.0, 0.0, 0.0, 0.0}, {2, 2}); + TF_CHECK_OK(s.ToGraphDef(graph_def)); + } + // Constructs the following graph. 
// // We have 2 different name scopes in this graph. s3, a3, a4, r3, and r4 are @@ -247,26 +288,43 @@ class ScopedAllocatorOptimizerTest : public ::testing::Test { } } + void GetNode(NodeMap* node_map, const string& node_name, NodeDef** node_def) { + *node_def = node_map->GetNode(node_name); + ASSERT_TRUE(*node_def); + } + // Validate that a node has a single control input from scoped allocator node. // Return the scoped allocator node. NodeDef* ValidateSAControlInput(GraphDef* graph, NodeMap* node_map, const string& node_name) { - NodeDef* node = node_map->GetNode(node_name); - EXPECT_TRUE(node); + NodeDef* node = nullptr; + GetNode(node_map, node_name, &node); int num_control_inputs = 0; string control_input_name; for (const auto& input : node->input()) { - if (input[0] == '^') { + if (IsControlInput(input)) { ++num_control_inputs; control_input_name = input; } } EXPECT_EQ(num_control_inputs, 1); - NodeDef* control_input_node = node_map->GetNode(control_input_name); - EXPECT_TRUE(control_input_node); + NodeDef* control_input_node = nullptr; + GetNode(node_map, control_input_name, &control_input_node); EXPECT_EQ(control_input_node->op(), "_ScopedAllocator"); return control_input_node; } + + int NumControlInputs(NodeMap* node_map, const string& node_name) { + NodeDef* node = nullptr; + GetNode(node_map, node_name, &node); + int num_control_inputs = 0; + for (const auto& input : node->input()) { + if (IsControlInput(input)) { + ++num_control_inputs; + } + } + return num_control_inputs; + } }; TEST_F(ScopedAllocatorOptimizerTest, UnaryRewriteOnly) { @@ -287,8 +345,8 @@ TEST_F(ScopedAllocatorOptimizerTest, UnaryRewriteOnly) { // Examine the resulting graph def. 
NodeMap node_map(&optimized_graph); - NodeDef* nd = node_map.GetNode("scoped_allocator_1_1"); - ASSERT_TRUE(nd); + NodeDef* nd = nullptr; + GetNode(&node_map, "scoped_allocator_1_1", &nd); { auto& nd_set = node_map.GetOutputs(nd->name()); ASSERT_EQ(3, nd_set.size()); @@ -420,6 +478,59 @@ TEST_F(ScopedAllocatorOptimizerTest, InputDependencies) { EXPECT_EQ(scoped_allocator_node->input(0), "^c"); } +// Test that graphs with input and output control edges are rewired correctly by +// the optimizer. +TEST_F(ScopedAllocatorOptimizerTest, ControlEdgeRewire) { + GrapplerItem item; + BuildAbsGraphWithInputAndOutputControlEdges(&item.graph); + SetShapes(&item.graph); + LOG(INFO) << item.graph.DebugString(); + + ScopedAllocatorOptions opts; + opts.add_enable_op("Abs"); + ScopedAllocatorOptimizer sao(RewriterConfig::ON, opts); + ScopedAllocatorOptimizer::OpNameSet ons; + ons.insert("Const"); + + GraphDef optimized_graph; + TF_ASSERT_OK(sao.Optimize(/*cluster=*/nullptr, item, &optimized_graph)); + TF_ASSERT_OK(TopologicalSort(&optimized_graph)); + NodeMap node_map(&optimized_graph); + LOG(INFO) << optimized_graph.DebugString(); + + // Check that ctl1 and ctl2 are now connected only to SAConcat. + NodeDef* ctl1 = nullptr; + GetNode(&node_map, "ctl1", &ctl1); + const auto& ctl1_outputs = node_map.GetOutputs("ctl1"); + EXPECT_EQ(ctl1_outputs.size(), 1); + NodeDef* sa_concat = *ctl1_outputs.begin(); + EXPECT_EQ(sa_concat->op(), "_ScopedAllocatorConcat"); + NodeDef* ctl2 = nullptr; + GetNode(&node_map, "ctl2", &ctl2); + const auto& ctl2_outputs = node_map.GetOutputs("ctl2"); + EXPECT_EQ(ctl2_outputs.size(), 1); + EXPECT_EQ(*ctl2_outputs.begin(), sa_concat); + + // Check that SAConcat has only 2 input control edges. + EXPECT_EQ(NumControlInputs(&node_map, sa_concat->name()), 2); + + // Check that fused node, which conceptually used to have control inputs from + // ctl1 and ctl2 respectively, no longer has any control inputs. 
+ const auto& sa_concat_outputs = node_map.GetOutputs(sa_concat->name()); + EXPECT_EQ(sa_concat_outputs.size(), 1); + NodeDef* fused_abs = *sa_concat_outputs.begin(); + EXPECT_EQ(NumControlInputs(&node_map, fused_abs->name()), 0); + + // Check that SASplit node has control edges to ctl3, ctl4; also check that + // those are the only control inputs on ctl3 and ctl4. + const auto& fused_abs_outputs = node_map.GetOutputs(fused_abs->name()); + EXPECT_EQ(fused_abs_outputs.size(), 1); + NodeDef* sa_split = *fused_abs_outputs.begin(); + EXPECT_EQ(NumControlOutputs(*sa_split, node_map), 2); + EXPECT_EQ(NumControlInputs(&node_map, "ctl3"), 1); + EXPECT_EQ(NumControlInputs(&node_map, "ctl4"), 1); +} + } // namespace } // namespace grappler } // namespace tensorflow