Add NodeMap update for incoming control edges in ScopedAllocatorOptimizer.
Ops that are being fused by the scoped allocator optimizer may have control edges that need to be retained once the ops are fused. The expectation is that in the rewired graph, the incoming control edges will be attached to the _ScopedAllocatorConcat node, and the outgoing control edges will be attached to the _ScopedAllocatorSplit node. This change adds a missing update to NodeMap while rewiring control edges. It also adds tests for confirming correct control edge rewiring. PiperOrigin-RevId: 281334731 Change-Id: I9b7f8ca4ccfc7d66657d0ee6a69ad5562083ccb7
This commit is contained in:
parent
ee9f39459d
commit
762b84bef3
|
@ -923,6 +923,7 @@ tf_cc_test(
|
|||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
"//tensorflow/core/grappler/utils:topological_sort",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -575,7 +575,8 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter {
|
|||
const TensorShape& sa_shape,
|
||||
std::vector<NodeDefBuilder::NodeOut>* sac_inputs) {
|
||||
VLOG(2) << "BuildSAConcatNode " << sac_name;
|
||||
std::set<string> sac_ctl_inputs;
|
||||
// control input: edge name -> source node name
|
||||
absl::flat_hash_map<string, string> sac_ctl_inputs;
|
||||
for (int i = 0; i < ops.size(); ++i) {
|
||||
NodeDef* old_op = ops[i];
|
||||
for (const string& old_op_input : old_op->input()) {
|
||||
|
@ -584,7 +585,7 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter {
|
|||
if (position == -1) {
|
||||
// A control input: drop if from another member of the op set.
|
||||
if (op_instance_names.find(old_op_input) == op_instance_names.end()) {
|
||||
sac_ctl_inputs.insert(old_op_input);
|
||||
sac_ctl_inputs.emplace(old_op_input, input_name);
|
||||
}
|
||||
} else {
|
||||
// TODO(tucker): remove redundant check.
|
||||
|
@ -620,8 +621,11 @@ class UnaryElementwiseRewriter : public ScopedAllocatorOptimizer::Rewriter {
|
|||
node_map->AddOutput(sa_name, sac_name);
|
||||
|
||||
// Attach the old control inputs to the new sac node.
|
||||
for (const string& ctl_input : sac_ctl_inputs) {
|
||||
sac_node->add_input(ctl_input);
|
||||
for (const auto& ctl_input : sac_ctl_inputs) {
|
||||
const auto& ctl_edge = ctl_input.first;
|
||||
const auto& input_name = ctl_input.second;
|
||||
sac_node->add_input(ctl_edge);
|
||||
node_map->AddOutput(input_name, sac_node->name());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
@ -147,6 +148,46 @@ class ScopedAllocatorOptimizerTest : public ::testing::Test {
|
|||
TF_CHECK_OK(s.ToGraphDef(graph_def));
|
||||
}
|
||||
|
||||
// Constructs the following graph.
|
||||
//
|
||||
// a and b are data inputs. ctl1 and ctl2 are control inputs. a1 and a2 are
|
||||
// Abs ops. o1 and o2 are data outputs. a1 -> ctl3 and a2 -> ctl4 are
|
||||
// control edges.
|
||||
//
|
||||
// After the optimizer runs, we expect the ctl1 and ctl2 to be connected to
|
||||
// the SAConcat node, and ctl3 and ctl4 to be connected to SASplit node.
|
||||
/*
|
||||
a ctl1 b ctl2
|
||||
\ / \ /
|
||||
a1 a2
|
||||
/ \ / \
|
||||
o1 ctl3 o2 ctl4
|
||||
*/
|
||||
// Builds a small test graph into `*graph_def`: two Abs ops (a1, a2), each with
// one data input, one incoming control dependency, one Reshape consumer and
// one node that depends on it through a control dependency.
void BuildAbsGraphWithInputAndOutputControlEdges(GraphDef* graph_def) {
|
||||
Scope s = Scope::NewRootScope();
|
||||
// Every node created through `s` is assigned to this single CPU device.
s = s.WithDevice("/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
|
||||
// Data inputs: 2x2 float constants filled with zeros.
Output a =
|
||||
ops::Const<float>(s.WithOpName("a"), {0.0, 0.0, 0.0, 0.0}, {2, 2});
|
||||
Output b =
|
||||
ops::Const<float>(s.WithOpName("b"), {0.0, 0.0, 0.0, 0.0}, {2, 2});
|
||||
// Sources of the incoming control edges; only used as control dependencies.
Output ctl1 =
|
||||
ops::Const<float>(s.WithOpName("ctl1"), {0.0, 0.0, 0.0, 0.0}, {2, 2});
|
||||
Output ctl2 =
|
||||
ops::Const<float>(s.WithOpName("ctl2"), {0.0, 0.0, 0.0, 0.0}, {2, 2});
|
||||
// The Abs ops under test: a1 = Abs(a) after ctl1, a2 = Abs(b) after ctl2.
Output a1 = ops::Abs(s.WithOpName("a1").WithControlDependencies({ctl1}), a);
|
||||
Output a2 = ops::Abs(s.WithOpName("a2").WithControlDependencies({ctl2}), b);
|
||||
// Data consumers of the Abs outputs, reshaped to {1, 4} and {4, 1}.
Output o1 = ops::Reshape(s.WithOpName("o1"), a1, {1, 4});
|
||||
Output o2 = ops::Reshape(s.WithOpName("o2"), a2, {4, 1});
|
||||
// Targets of the outgoing control edges: constants that carry a control
// dependency on a1 and a2 respectively.
Output ctl3 =
|
||||
ops::Const<float>(s.WithOpName("ctl3").WithControlDependencies({a1}),
|
||||
{0.0, 0.0, 0.0, 0.0}, {2, 2});
|
||||
Output ctl4 =
|
||||
ops::Const<float>(s.WithOpName("ctl4").WithControlDependencies({a2}),
|
||||
{0.0, 0.0, 0.0, 0.0}, {2, 2});
|
||||
// Export the scope's graph into the caller's GraphDef; the returned status
// is checked with TF_CHECK_OK.
TF_CHECK_OK(s.ToGraphDef(graph_def));
|
||||
}
|
||||
|
||||
// Constructs the following graph.
|
||||
//
|
||||
// We have 2 different name scopes in this graph. s3, a3, a4, r3, and r4 are
|
||||
|
@ -247,26 +288,43 @@ class ScopedAllocatorOptimizerTest : public ::testing::Test {
|
|||
}
|
||||
}
|
||||
|
||||
// Looks up `node_name` in `node_map`, stores the result in `*node_def`, and
// asserts that the result is non-null.
//
// The result is returned through an out-parameter and the function is void,
// presumably because gtest ASSERT_* macros may only be used in functions
// returning void.
// NOTE(review): a failed ASSERT_TRUE here returns from this helper only; the
// caller keeps running with a null `*node_def` unless it checks for a fatal
// failure itself -- confirm this is acceptable at each call site.
void GetNode(NodeMap* node_map, const string& node_name, NodeDef** node_def) {
|
||||
*node_def = node_map->GetNode(node_name);
|
||||
ASSERT_TRUE(*node_def);
|
||||
}
|
||||
|
||||
// Validate that a node has a single control input from scoped allocator node.
|
||||
// Return the scoped allocator node.
|
||||
// Finds `node_name` in `node_map`, counts its control inputs, expects exactly
// one, then looks that input up and expects its op to be "_ScopedAllocator".
// Returns the node found for the control input. `graph` is not referenced in
// the body shown here.
//
// NOTE(review): this listing shows each changed statement twice (e.g. `node`
// and `control_input_node` are declared two times, and there are two `if`
// openers for one closing brace). Presumably the first variant of each pair is
// the pre-change code and the second is its replacement -- confirm against the
// actual file before editing.
NodeDef* ValidateSAControlInput(GraphDef* graph, NodeMap* node_map,
|
||||
const string& node_name) {
|
||||
// Variant 1: direct lookup with a non-fatal check.
NodeDef* node = node_map->GetNode(node_name);
|
||||
EXPECT_TRUE(node);
|
||||
// Variant 2: lookup through the GetNode helper.
NodeDef* node = nullptr;
|
||||
GetNode(node_map, node_name, &node);
|
||||
int num_control_inputs = 0;
|
||||
// Remembers the most recently seen control input string.
string control_input_name;
|
||||
for (const auto& input : node->input()) {
|
||||
// Variant 1: treat inputs starting with '^' as control inputs.
if (input[0] == '^') {
|
||||
// Variant 2: delegate the check to IsControlInput.
if (IsControlInput(input)) {
|
||||
++num_control_inputs;
|
||||
control_input_name = input;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(num_control_inputs, 1);
|
||||
// NOTE(review): `control_input_name` is the raw input string that passed the
// control-input check; presumably NodeMap::GetNode accepts that form --
// verify.
// Variant 1: direct lookup with a non-fatal check.
NodeDef* control_input_node = node_map->GetNode(control_input_name);
|
||||
EXPECT_TRUE(control_input_node);
|
||||
// Variant 2: lookup through the GetNode helper.
NodeDef* control_input_node = nullptr;
|
||||
GetNode(node_map, control_input_name, &control_input_node);
|
||||
EXPECT_EQ(control_input_node->op(), "_ScopedAllocator");
|
||||
return control_input_node;
|
||||
}
|
||||
|
||||
// Returns how many inputs of the node named `node_name` satisfy
// IsControlInput. The node is looked up in `node_map` via GetNode.
// NOTE(review): `node` is dereferenced right after GetNode without a null
// check in this function -- confirm a missing node cannot reach the loop.
int NumControlInputs(NodeMap* node_map, const string& node_name) {
|
||||
NodeDef* node = nullptr;
|
||||
GetNode(node_map, node_name, &node);
|
||||
int num_control_inputs = 0;
|
||||
// Scan every input string of the node and count the control ones.
for (const auto& input : node->input()) {
|
||||
if (IsControlInput(input)) {
|
||||
++num_control_inputs;
|
||||
}
|
||||
}
|
||||
return num_control_inputs;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ScopedAllocatorOptimizerTest, UnaryRewriteOnly) {
|
||||
|
@ -287,8 +345,8 @@ TEST_F(ScopedAllocatorOptimizerTest, UnaryRewriteOnly) {
|
|||
|
||||
// Examine the resulting graph def.
|
||||
NodeMap node_map(&optimized_graph);
|
||||
NodeDef* nd = node_map.GetNode("scoped_allocator_1_1");
|
||||
ASSERT_TRUE(nd);
|
||||
NodeDef* nd = nullptr;
|
||||
GetNode(&node_map, "scoped_allocator_1_1", &nd);
|
||||
{
|
||||
auto& nd_set = node_map.GetOutputs(nd->name());
|
||||
ASSERT_EQ(3, nd_set.size());
|
||||
|
@ -420,6 +478,59 @@ TEST_F(ScopedAllocatorOptimizerTest, InputDependencies) {
|
|||
EXPECT_EQ(scoped_allocator_node->input(0), "^c");
|
||||
}
|
||||
|
||||
// Test that graphs with input and output control edges are rewired correctly by
|
||||
// the optimizer.
|
||||
// Runs ScopedAllocatorOptimizer with the "Abs" op enabled on the graph from
// BuildAbsGraphWithInputAndOutputControlEdges, then walks the optimized graph
// from ctl1/ctl2 through their outputs and checks the control-edge counts
// along the way.
TEST_F(ScopedAllocatorOptimizerTest, ControlEdgeRewire) {
|
||||
GrapplerItem item;
|
||||
BuildAbsGraphWithInputAndOutputControlEdges(&item.graph);
|
||||
SetShapes(&item.graph);
|
||||
// Log the input graph for debugging.
LOG(INFO) << item.graph.DebugString();
|
||||
|
||||
ScopedAllocatorOptions opts;
|
||||
opts.add_enable_op("Abs");
|
||||
ScopedAllocatorOptimizer sao(RewriterConfig::ON, opts);
|
||||
// NOTE(review): `ons` is filled here but not referenced again in this test.
ScopedAllocatorOptimizer::OpNameSet ons;
|
||||
ons.insert("Const");
|
||||
|
||||
GraphDef optimized_graph;
|
||||
TF_ASSERT_OK(sao.Optimize(/*cluster=*/nullptr, item, &optimized_graph));
|
||||
// The optimized graph is topologically sorted before the NodeMap is built
// over it.
TF_ASSERT_OK(TopologicalSort(&optimized_graph));
|
||||
NodeMap node_map(&optimized_graph);
|
||||
// Log the optimized graph for debugging.
LOG(INFO) << optimized_graph.DebugString();
|
||||
|
||||
// Check that ctl1 and ctl2 are now connected only to SAConcat.
|
||||
// `ctl1` itself is not used afterwards; the GetNode call only asserts that the
// node exists.
NodeDef* ctl1 = nullptr;
|
||||
GetNode(&node_map, "ctl1", &ctl1);
|
||||
const auto& ctl1_outputs = node_map.GetOutputs("ctl1");
|
||||
// NOTE(review): EXPECT_EQ is non-fatal, so begin() below is dereferenced even
// if the size check fails.
EXPECT_EQ(ctl1_outputs.size(), 1);
|
||||
NodeDef* sa_concat = *ctl1_outputs.begin();
|
||||
EXPECT_EQ(sa_concat->op(), "_ScopedAllocatorConcat");
|
||||
// Same existence check for ctl2; its single output must be the same node.
NodeDef* ctl2 = nullptr;
|
||||
GetNode(&node_map, "ctl2", &ctl2);
|
||||
const auto& ctl2_outputs = node_map.GetOutputs("ctl2");
|
||||
EXPECT_EQ(ctl2_outputs.size(), 1);
|
||||
EXPECT_EQ(*ctl2_outputs.begin(), sa_concat);
|
||||
|
||||
// Check that SAConcat has only 2 input control edges.
|
||||
EXPECT_EQ(NumControlInputs(&node_map, sa_concat->name()), 2);
|
||||
|
||||
// Check that fused node, which conceptually used to have control inputs from
|
||||
// ctl1 and ctl2 respectively, no longer has any control inputs.
|
||||
// The single output of SAConcat is taken to be the fused Abs node.
const auto& sa_concat_outputs = node_map.GetOutputs(sa_concat->name());
|
||||
EXPECT_EQ(sa_concat_outputs.size(), 1);
|
||||
NodeDef* fused_abs = *sa_concat_outputs.begin();
|
||||
EXPECT_EQ(NumControlInputs(&node_map, fused_abs->name()), 0);
|
||||
|
||||
// Check that SASplit node has control edges to ctl3, ctl4; also check that
|
||||
// those are the only control inputs on ctl3 and ctl4.
|
||||
// The single output of the fused node is taken to be SASplit; its op is not
// checked here.
const auto& fused_abs_outputs = node_map.GetOutputs(fused_abs->name());
|
||||
EXPECT_EQ(fused_abs_outputs.size(), 1);
|
||||
NodeDef* sa_split = *fused_abs_outputs.begin();
|
||||
EXPECT_EQ(NumControlOutputs(*sa_split, node_map), 2);
|
||||
EXPECT_EQ(NumControlInputs(&node_map, "ctl3"), 1);
|
||||
EXPECT_EQ(NumControlInputs(&node_map, "ctl4"), 1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
|
Loading…
Reference in New Issue