Add NodeMap update for incoming control edges in ScopedAllocatorOptimizer.

Ops that are being fused by the scoped allocator optimizer may have control
edges that need to be retained once the ops are fused.  The expectation is that
in the rewired graph, the incoming control edges will be attached to the
_ScopedAllocatorConcat node, and the outgoing control edges will be attached
to the _ScopedAllocatorSplit node.

This change adds a missing update to NodeMap while rewiring control edges.  It
also adds tests for confirming correct control edge rewiring.

PiperOrigin-RevId: 281334731
Change-Id: I9b7f8ca4ccfc7d66657d0ee6a69ad5562083ccb7
This commit is contained in:
Ayush Dubey 2019-11-19 11:10:43 -08:00 committed by TensorFlower Gardener
parent ee9f39459d
commit 762b84bef3
3 changed files with 127 additions and 11 deletions

View File

@ -923,6 +923,7 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils", "//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
"//tensorflow/core/grappler/utils:topological_sort",
], ],
) )

View File

@ -575,7 +575,8 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter {
const TensorShape& sa_shape, const TensorShape& sa_shape,
std::vector<NodeDefBuilder::NodeOut>* sac_inputs) { std::vector<NodeDefBuilder::NodeOut>* sac_inputs) {
VLOG(2) << "BuildSAConcatNode " << sac_name; VLOG(2) << "BuildSAConcatNode " << sac_name;
std::set<string> sac_ctl_inputs; // control input: edge name -> source node name
absl::flat_hash_map<string, string> sac_ctl_inputs;
for (int i = 0; i < ops.size(); ++i) { for (int i = 0; i < ops.size(); ++i) {
NodeDef* old_op = ops[i]; NodeDef* old_op = ops[i];
for (const string& old_op_input : old_op->input()) { for (const string& old_op_input : old_op->input()) {
@ -584,7 +585,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter {
if (position == -1) { if (position == -1) {
// A control input: drop if from another member of the op set. // A control input: drop if from another member of the op set.
if (op_instance_names.find(old_op_input) == op_instance_names.end()) { if (op_instance_names.find(old_op_input) == op_instance_names.end()) {
sac_ctl_inputs.insert(old_op_input); sac_ctl_inputs.emplace(old_op_input, input_name);
} }
} else { } else {
// TODO(tucker): remove redundant check. // TODO(tucker): remove redundant check.
@ -620,8 +621,11 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter {
node_map->AddOutput(sa_name, sac_name); node_map->AddOutput(sa_name, sac_name);
// Attach the old control inputs to the new sac node. // Attach the old control inputs to the new sac node.
for (const string& ctl_input : sac_ctl_inputs) { for (const auto& ctl_input : sac_ctl_inputs) {
sac_node->add_input(ctl_input); const auto& ctl_edge = ctl_input.first;
const auto& input_name = ctl_input.second;
sac_node->add_input(ctl_edge);
node_map->AddOutput(input_name, sac_node->name());
} }
return Status::OK(); return Status::OK();
} }

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/graph/testlib.h" #include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
@ -147,6 +148,46 @@ class ScopedAllocatorOptimizerTest : public ::testing::Test {
TF_CHECK_OK(s.ToGraphDef(graph_def)); TF_CHECK_OK(s.ToGraphDef(graph_def));
} }
// Builds a small graph with control edges entering and leaving two Abs ops,
// and serializes it into `graph_def`.
//
// Nodes (all placed on /job:localhost/replica:0/task:0/device:CPU:0):
//   a, b        - 2x2 float constants, the data inputs of a1 and a2.
//   ctl1, ctl2  - 2x2 float constants, control inputs of a1 and a2.
//   a1, a2      - Abs ops.
//   o1, o2      - Reshape ops consuming a1 and a2 (the data outputs).
//   ctl3, ctl4  - 2x2 float constants with control inputs from a1 and a2.
//
// After the optimizer runs, ctl1 and ctl2 are expected to be connected to the
// SAConcat node, and ctl3 and ctl4 to the SASplit node.
/*
      a  ctl1   b  ctl2
       \  /      \  /
        a1        a2
       /  \      /  \
     o1  ctl3  o2  ctl4
*/
void BuildAbsGraphWithInputAndOutputControlEdges(GraphDef* graph_def) {
  Scope root = Scope::NewRootScope().WithDevice(
      "/job:localhost/replica:0/task:0/device:CPU:0");

  // Every constant in this graph is the same 2x2 block of zeros; only the
  // scope (name and control dependencies) differs.
  const auto zeros_2x2 = [](const Scope& scope) -> Output {
    return ops::Const<float>(scope, {0.0, 0.0, 0.0, 0.0}, {2, 2});
  };

  // Data and control sources.
  Output a = zeros_2x2(root.WithOpName("a"));
  Output b = zeros_2x2(root.WithOpName("b"));
  Output ctl1 = zeros_2x2(root.WithOpName("ctl1"));
  Output ctl2 = zeros_2x2(root.WithOpName("ctl2"));

  // The Abs ops, each with one incoming control edge.
  Output a1 =
      ops::Abs(root.WithOpName("a1").WithControlDependencies({ctl1}), a);
  Output a2 =
      ops::Abs(root.WithOpName("a2").WithControlDependencies({ctl2}), b);

  // Data consumers of the Abs ops.
  Output o1 = ops::Reshape(root.WithOpName("o1"), a1, {1, 4});
  Output o2 = ops::Reshape(root.WithOpName("o2"), a2, {4, 1});

  // Control consumers of the Abs ops.
  Output ctl3 =
      zeros_2x2(root.WithOpName("ctl3").WithControlDependencies({a1}));
  Output ctl4 =
      zeros_2x2(root.WithOpName("ctl4").WithControlDependencies({a2}));

  TF_CHECK_OK(root.ToGraphDef(graph_def));
}
// Constructs the following graph. // Constructs the following graph.
// //
// We have 2 different name scopes in this graph. s3, a3, a4, r3, and r4 are // We have 2 different name scopes in this graph. s3, a3, a4, r3, and r4 are
@ -247,26 +288,43 @@ class ScopedAllocatorOptimizerTest : public ::testing::Test {
} }
} }
// Looks up `node_name` in `node_map`, stores the result in `*node_def`, and
// records a fatal test failure if no such node exists.
//
// Note: ASSERT_TRUE only returns from this helper, not from the calling test,
// so `*node_def` may still be null when control returns to the caller.
void GetNode(NodeMap* node_map, const string& node_name, NodeDef** node_def) {
  NodeDef* const found = node_map->GetNode(node_name);
  *node_def = found;
  ASSERT_TRUE(found);
}
// Validate that a node has a single control input from scoped allocator node. // Validate that a node has a single control input from scoped allocator node.
// Return the scoped allocator node. // Return the scoped allocator node.
NodeDef* ValidateSAControlInput(GraphDef* graph, NodeMap* node_map, NodeDef* ValidateSAControlInput(GraphDef* graph, NodeMap* node_map,
const string& node_name) { const string& node_name) {
NodeDef* node = node_map->GetNode(node_name); NodeDef* node = nullptr;
EXPECT_TRUE(node); GetNode(node_map, node_name, &node);
int num_control_inputs = 0; int num_control_inputs = 0;
string control_input_name; string control_input_name;
for (const auto& input : node->input()) { for (const auto& input : node->input()) {
if (input[0] == '^') { if (IsControlInput(input)) {
++num_control_inputs; ++num_control_inputs;
control_input_name = input; control_input_name = input;
} }
} }
EXPECT_EQ(num_control_inputs, 1); EXPECT_EQ(num_control_inputs, 1);
NodeDef* control_input_node = node_map->GetNode(control_input_name); NodeDef* control_input_node = nullptr;
EXPECT_TRUE(control_input_node); GetNode(node_map, control_input_name, &control_input_node);
EXPECT_EQ(control_input_node->op(), "_ScopedAllocator"); EXPECT_EQ(control_input_node->op(), "_ScopedAllocator");
return control_input_node; return control_input_node;
} }
// Returns the number of control inputs of the node named `node_name`, i.e.
// the number of its inputs for which IsControlInput() is true.
//
// If the node is not present in `node_map`, GetNode() has already recorded a
// fatal failure; in that case -1 is returned instead of dereferencing a null
// node, so the caller's count expectation also fails rather than crashing.
int NumControlInputs(NodeMap* node_map, const string& node_name) {
  NodeDef* node = nullptr;
  GetNode(node_map, node_name, &node);
  // The ASSERT inside GetNode only aborts GetNode itself, so guard here.
  if (node == nullptr) {
    return -1;
  }
  int num_control_inputs = 0;
  for (const auto& input : node->input()) {
    if (IsControlInput(input)) {
      ++num_control_inputs;
    }
  }
  return num_control_inputs;
}
}; };
TEST_F(ScopedAllocatorOptimizerTest, UnaryRewriteOnly) { TEST_F(ScopedAllocatorOptimizerTest, UnaryRewriteOnly) {
@ -287,8 +345,8 @@ TEST_F(ScopedAllocatorOptimizerTest, UnaryRewriteOnly) {
// Examine the resulting graph def. // Examine the resulting graph def.
NodeMap node_map(&optimized_graph); NodeMap node_map(&optimized_graph);
NodeDef* nd = node_map.GetNode("scoped_allocator_1_1"); NodeDef* nd = nullptr;
ASSERT_TRUE(nd); GetNode(&node_map, "scoped_allocator_1_1", &nd);
{ {
auto& nd_set = node_map.GetOutputs(nd->name()); auto& nd_set = node_map.GetOutputs(nd->name());
ASSERT_EQ(3, nd_set.size()); ASSERT_EQ(3, nd_set.size());
@ -420,6 +478,59 @@ TEST_F(ScopedAllocatorOptimizerTest, InputDependencies) {
EXPECT_EQ(scoped_allocator_node->input(0), "^c"); EXPECT_EQ(scoped_allocator_node->input(0), "^c");
} }
// Test that graphs with input and output control edges are rewired correctly by
// the optimizer.
//
// Size checks that guard a subsequent `*set.begin()` use ASSERT_EQ so that an
// unexpected (e.g. empty) output set stops the test instead of dereferencing
// an invalid iterator.
TEST_F(ScopedAllocatorOptimizerTest, ControlEdgeRewire) {
  GrapplerItem item;
  BuildAbsGraphWithInputAndOutputControlEdges(&item.graph);
  SetShapes(&item.graph);
  LOG(INFO) << item.graph.DebugString();

  // Run the optimizer with Abs enabled as a fusable op.
  ScopedAllocatorOptions opts;
  opts.add_enable_op("Abs");
  ScopedAllocatorOptimizer sao(RewriterConfig::ON, opts);

  GraphDef optimized_graph;
  TF_ASSERT_OK(sao.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
  TF_ASSERT_OK(TopologicalSort(&optimized_graph));
  NodeMap node_map(&optimized_graph);
  LOG(INFO) << optimized_graph.DebugString();

  // Check that ctl1 and ctl2 are now connected only to SAConcat.
  NodeDef* ctl1 = nullptr;
  GetNode(&node_map, "ctl1", &ctl1);
  const auto& ctl1_outputs = node_map.GetOutputs("ctl1");
  ASSERT_EQ(ctl1_outputs.size(), 1);
  NodeDef* sa_concat = *ctl1_outputs.begin();
  EXPECT_EQ(sa_concat->op(), "_ScopedAllocatorConcat");

  NodeDef* ctl2 = nullptr;
  GetNode(&node_map, "ctl2", &ctl2);
  const auto& ctl2_outputs = node_map.GetOutputs("ctl2");
  ASSERT_EQ(ctl2_outputs.size(), 1);
  EXPECT_EQ(*ctl2_outputs.begin(), sa_concat);

  // Check that SAConcat has only 2 input control edges.
  EXPECT_EQ(NumControlInputs(&node_map, sa_concat->name()), 2);

  // Check that fused node, which conceptually used to have control inputs from
  // ctl1 and ctl2 respectively, no longer has any control inputs.
  const auto& sa_concat_outputs = node_map.GetOutputs(sa_concat->name());
  ASSERT_EQ(sa_concat_outputs.size(), 1);
  NodeDef* fused_abs = *sa_concat_outputs.begin();
  EXPECT_EQ(NumControlInputs(&node_map, fused_abs->name()), 0);

  // Check that SASplit node has control edges to ctl3, ctl4; also check that
  // those are the only control inputs on ctl3 and ctl4.
  const auto& fused_abs_outputs = node_map.GetOutputs(fused_abs->name());
  ASSERT_EQ(fused_abs_outputs.size(), 1);
  NodeDef* sa_split = *fused_abs_outputs.begin();
  EXPECT_EQ(NumControlOutputs(*sa_split, node_map), 2);
  EXPECT_EQ(NumControlInputs(&node_map, "ctl3"), 1);
  EXPECT_EQ(NumControlInputs(&node_map, "ctl4"), 1);
}
} // namespace } // namespace
} // namespace grappler } // namespace grappler
} // namespace tensorflow } // namespace tensorflow