[Grappler]

1) Skip dead branch elimination for merge nodes with control inputs, since these can create cycles in the resulting optimized graph (see the sketch below).
2) Optimize a few utility functions.
3) Add more verbose VLOGging when topological sorting fails.

PiperOrigin-RevId: 284871268
Change-Id: I36435402d826e4737b709468d88641d7a7fa2a83
A. Unique TensorFlower 2019-12-10 16:10:36 -08:00 committed by TensorFlower Gardener
parent 77b30d97cb
commit 74229d4736
7 changed files with 84 additions and 62 deletions
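For item 1, some context: a Merge node's "N" attribute holds its number of data inputs, and any control inputs are stored as extra "^"-prefixed entries after the data inputs, so a Merge with control inputs is recognizable by input_size() exceeding N. That is the guard the new RemoveDeadBranches code applies before bailing out. Below is a minimal sketch of the same check on a simplified stand-in for tensorflow::NodeDef (the FakeNode struct and ShouldSkipMerge name are illustrative, not TensorFlow APIs):

#include <string>
#include <vector>

// Simplified stand-in for tensorflow::NodeDef, for illustration only.
struct FakeNode {
  std::string name;
  int num_data_inputs;              // the Merge node's "N" attribute
  std::vector<std::string> inputs;  // data inputs first, then "^ctrl" inputs
};

// Mirrors the guard added in RemoveDeadBranches: if the input list is longer
// than N, the Merge carries control inputs and dead-branch removal is skipped.
bool ShouldSkipMerge(const FakeNode& merge) {
  return static_cast<int>(merge.inputs.size()) != merge.num_data_inputs;
}

Under this model a node such as {"m6", 2, {"v_in", "square1", "^sqrt2"}} is skipped, which is why the m6/m7 cases are deleted from the unit test further down.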

View File

@@ -627,7 +627,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
return Status::OK();
}
VLOG(3) << "Try to find a zero iteration while loop:"
VLOG(4) << "Try to find a zero iteration while loop:"
<< " switch_node=" << switch_node.name();
// Find the boolean predicate from a LoopCond node (e.g. Greater).
@@ -704,7 +704,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
&constant_switch_value));
if (constant_switch_value == false) {
VLOG(4) << "Remove 0 iteration while loop:"
VLOG(3) << "Remove 0 iteration while loop:"
<< " switch_node=" << switch_node.name();
*has_dead_fanout = true;
*dead_fanout = 1;
@@ -746,8 +746,6 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph));
}
if (options_.enable_dead_branch_removal) {
// TODO(srjoglekar): Figure out if we can optimize NodeMap creations across
// optimizer passes.
NodeMap node_map(optimized_graph);
absl::flat_hash_set<string> feed_nodes;
for (const auto& feed : item.feed) {
@@ -890,43 +888,55 @@ Status LoopOptimizer::RemoveDeadBranches(
// Names of the nodes that were removed from the graph.
absl::flat_hash_set<absl::string_view> dead_node_names;
dead_node_names.reserve(dead_nodes.size());
for (const NodeDef* dead_node : dead_nodes)
for (const NodeDef* dead_node : dead_nodes) {
dead_node_names.insert(dead_node->name());
}
// Remove dead inputs from Merge nodes that were not pruned from the graph.
// Check that the merge nodes are valid.
for (const auto& itr : dead_merge_inputs) {
NodeDef* dead_node = itr.first;
if (dead_nodes.find(dead_node) != dead_nodes.end()) {
// The node has been pruned since all its inputs are dead.
NodeDef* merge_node = itr.first;
if (dead_nodes.find(merge_node) != dead_nodes.end()) {
// The node will be pruned since all its inputs are dead.
continue;
}
// Remove dead data input.
const std::set<int>& dead_inputs = itr.second;
CHECK_LE(dead_inputs.size(), 1);
// (This loop would delete >1 items possibly in the wrong order.)
for (int index : dead_inputs) {
dead_node->mutable_input()->DeleteSubrange(index, 1);
const int num_data_inputs = merge_node->attr().at("N").i();
if (merge_node->input_size() != num_data_inputs) {
LOG(WARNING)
<< "Skipping loop optimization for Merge node with control input: "
<< merge_node->name();
return Status::OK();
} else if (dead_inputs.size() != 1 || num_data_inputs != 2) {
LOG(WARNING) << "Skipping loop optimization for Merge node ("
<< merge_node->name()
<< ") with unexpected dead_inputs.size() ("
<< dead_inputs.size() << ") or num_data_inputs ("
<< num_data_inputs << ")";
return Status::OK();
}
// Turn Merge into Identity only if we deleted the other data input.
if (!dead_inputs.empty()) {
const int num_data_inputs = dead_node->attr().at("N").i();
CHECK_EQ(num_data_inputs, dead_inputs.size() + 1);
dead_node->set_op("Identity");
dead_node->mutable_attr()->erase("N");
}
// Remove control inputs from dead nodes.
int pos = 0;
while (pos < dead_node->input_size()) {
TensorId tensor = ParseTensorName(dead_node->input(pos));
if (tensor.index() == Graph::kControlSlot &&
dead_node_names.contains(tensor.node())) {
auto* inputs = dead_node->mutable_input();
inputs->SwapElements(pos, dead_node->input_size() - 1);
inputs->RemoveLast();
} else {
++pos;
}
}
// Remove dead inputs from Merge nodes that will not be pruned from the
// graph.
for (const auto& itr : dead_merge_inputs) {
NodeDef* merge_node = itr.first;
if (dead_nodes.find(merge_node) != dead_nodes.end()) {
// The node will be pruned since all its inputs are dead.
continue;
}
VLOG(3) << "Merge node before cleanup: " << merge_node->DebugString();
// Remove dead data input.
const std::set<int>& dead_inputs = itr.second;
int index = *dead_inputs.begin();
auto* inputs = merge_node->mutable_input();
inputs->SwapElements(1, index);
inputs->SwapElements(1, merge_node->input_size() - 1);
inputs->RemoveLast();
merge_node->set_op("Identity");
merge_node->mutable_attr()->erase("N");
VLOG(3) << "Merge node after cleanup: " << merge_node->DebugString();
}
EraseNodesFromGraph(std::move(nodes_idx_to_delete), optimized_graph);
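To recap the second loop above: for a Merge that survives pruning, the earlier guard has already ensured it has exactly two data inputs, no control inputs, and exactly one dead input, so the cleanup swaps the dead data input to the back, pops it, retypes the node as Identity, and erases the now-irrelevant "N" attribute. A minimal sketch of that rewrite, again on a simplified stand-in for NodeDef rather than the real proto API:

#include <string>
#include <utility>
#include <vector>

struct FakeNode {  // simplified stand-in for tensorflow::NodeDef
  std::string op;
  bool has_n_attr;                  // models the presence of the "N" attribute
  std::vector<std::string> inputs;  // exactly two data inputs in this scenario
};

// Rewrites a two-input Merge whose dead data input sits at dead_index (0 or 1),
// mirroring the SwapElements/RemoveLast sequence in RemoveDeadBranches.
void RewriteMergeAsIdentity(FakeNode* merge, int dead_index) {
  std::swap(merge->inputs[1], merge->inputs[dead_index]);  // move dead input last
  merge->inputs.pop_back();                                // drop the dead input
  merge->op = "Identity";                                  // single-input passthrough
  merge->has_n_attr = false;                               // erase("N")
}

If dead_index is already 1 the swap is a no-op, which is also why the original code's second SwapElements(1, input_size() - 1) does nothing for a two-input Merge.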

View File

@@ -777,10 +777,6 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
ops::Merge m3(scope.WithOpName("m3"), {v_in, sqrt1});
ops::Merge m4(scope.WithOpName("m4"), {square1, sqrt2});
ops::Merge m5(scope.WithOpName("m5"), {square2, sqrt1});
ops::Merge m6(scope.WithOpName("m6").WithControlDependencies(sqrt2),
{v_in, square1});
ops::Merge m7(scope.WithOpName("m7").WithControlDependencies(sqrt1),
{v_in, square1});
ops::Switch s5(scope.WithOpName("switch5"), v_in, ctrl1);
Output id1 = ops::Identity(scope.WithOpName("id1"), s5.output_false);
@@ -831,19 +827,6 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
ASSERT_EQ(node.input_size(), 2);
EXPECT_EQ(node.input(0), "square1");
EXPECT_EQ(node.input(1), "sqrt2");
} else if (node.name() == "m6") {
// both inputs are alive and the control dependency can get triggered
EXPECT_EQ(node.op(), "Merge");
ASSERT_EQ(node.input_size(), 3);
EXPECT_EQ(node.input(0), "v_in");
EXPECT_EQ(node.input(1), "square1");
EXPECT_EQ(node.input(2), "^sqrt2");
} else if (node.name() == "m7") {
// removed control input from dead sqrt1
EXPECT_EQ(node.op(), "Merge");
ASSERT_EQ(node.input_size(), 2);
EXPECT_EQ(node.input(0), "v_in");
EXPECT_EQ(node.input(1), "square1");
} else if (node.name() == "m8") {
// The node is to be preserved because of a fetch
EXPECT_EQ(node.op(), "Merge");
@@ -859,11 +842,11 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
}
}
auto tensors_expected = EvaluateNodes(item.graph, {"m7", "m8", "m9"});
ASSERT_EQ(tensors_expected.size(), 3);
auto tensors_expected = EvaluateNodes(item.graph, {"m8", "m9"});
ASSERT_EQ(tensors_expected.size(), 2);
auto tensors = EvaluateNodes(output, {"m7", "m8", "m9"});
ASSERT_EQ(tensors.size(), 3);
auto tensors = EvaluateNodes(output, {"m8", "m9"});
ASSERT_EQ(tensors.size(), 2);
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-6);
@@ -1098,7 +1081,6 @@ node {
op: "Merge"
input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency_1"
input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency"
input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert"
device: "/job:localhost/replica:0/task:0/device:CPU:0"
attr {
key: "N"

View File

@@ -221,7 +221,7 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.function_optimization() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<FunctionOptimizer>(
cfg_.function_optimization(),
/*lower_contorl_flow=*/!IsSingleThreadedExecutor()));
/*lower_control_flow=*/!IsSingleThreadedExecutor()));
}
if (cfg_.debug_stripper() == RewriterConfig::ON) {
optimizers->push_back(MakeUnique<DebugStripper>());

View File

@@ -277,10 +277,22 @@ bool HasRegularInputs(const NodeDef& node) {
}
int NumNonControlInputs(const NodeDef& node) {
int num_inputs = node.input_size();
for (const string& input : node.input()) {
int num_inputs = 0;
for (; num_inputs < node.input_size(); ++num_inputs) {
const string& input = node.input(num_inputs);
if (IsControlInput(input)) {
--num_inputs;
return num_inputs;
}
}
return num_inputs;
}
int NumControlInputs(const NodeDef& node) {
int num_inputs = 0;
for (; num_inputs < node.input_size(); ++num_inputs) {
const string& input = node.input(node.input_size() - num_inputs - 1);
if (!IsControlInput(input)) {
return num_inputs;
}
}
return num_inputs;
@@ -302,8 +314,9 @@ bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) {
bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) {
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
for (const string& node_as_input : output->input()) {
if (!IsControlInput(node_as_input)) continue;
for (int idx = output->input_size() - 1; idx >= 0; --idx) {
const string& node_as_input = output->input(idx);
if (!IsControlInput(node_as_input)) break;
TensorId tensor = ParseTensorName(node_as_input);
if (tensor.node() == node.name()) {
@@ -317,8 +330,9 @@ bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) {
int NumControlOutputs(const NodeDef& node, const NodeMap& node_map) {
int num_outputs = 0;
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
for (const string& node_as_input : output->input()) {
if (!IsControlInput(node_as_input)) continue;
for (int idx = output->input_size() - 1; idx >= 0; --idx) {
const string& node_as_input = output->input(idx);
if (!IsControlInput(node_as_input)) break;
TensorId tensor = ParseTensorName(node_as_input);
if (tensor.node() == node.name()) {
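The rewritten counters in this file lean on the GraphDef convention that a node's regular inputs come first and all "^"-prefixed control inputs are appended after them: NumNonControlInputs can stop at the first control input, NumControlInputs can count backwards from the end, and the output-side helpers likewise scan each consumer's input list from the back, breaking at the first non-control input. A small self-contained sketch of the same idea over a plain vector of input strings (the helper names here are illustrative, not the Grappler API):

#include <string>
#include <vector>

// In a GraphDef, control inputs are encoded as "^node_name".
static bool IsControlEdge(const std::string& input) {
  return !input.empty() && input[0] == '^';
}

// Counts leading non-control inputs, stopping at the first control input.
int CountNonControlInputs(const std::vector<std::string>& inputs) {
  int n = 0;
  while (n < static_cast<int>(inputs.size()) && !IsControlEdge(inputs[n])) ++n;
  return n;
}

// Counts trailing control inputs, scanning from the back.
int CountControlInputs(const std::vector<std::string>& inputs) {
  int n = 0;
  while (n < static_cast<int>(inputs.size()) &&
         IsControlEdge(inputs[inputs.size() - 1 - n])) {
    ++n;
  }
  return n;
}

For inputs {"a", "b", "^c"} these return 2 and 1, matching the new NumNonControlInputs/NumControlInputs expectations in the updated utils test below.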

View File

@@ -221,6 +221,9 @@ bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map);
// Returns true iff the node has at least one control output.
bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map);
// Number of connected control inputs.
int NumControlInputs(const NodeDef& node);
// Number of connected non-control inputs.
int NumNonControlInputs(const NodeDef& node);

View File

@@ -90,6 +90,16 @@ Status ComputeTopologicalOrder(
}
if (back != graph_view.num_nodes()) {
if (VLOG_IS_ON(1)) {
VLOG(1) << "The graph couldn't be sorted in topological order. Stalled "
"at node = "
<< graph.node(back).DebugString();
for (int i = 0; i < graph_view.num_nodes(); ++i) {
if (num_ready_inputs[i] != graph_view.GetFanin(i).size()) {
VLOG(1) << "Node not ready: " << graph.node(i).DebugString();
}
}
}
return errors::InvalidArgument(
"The graph couldn't be sorted in topological order.");
}
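The new diagnostics fire when the Kahn-style traversal in ComputeTopologicalOrder stalls, i.e. when some nodes never see all of their fanins become ready (typically because the graph contains a cycle); at VLOG(1) the stalled node and every not-ready node are dumped before the InvalidArgument error is returned. A rough self-contained sketch of that detection idea on a plain adjacency list (not the actual GraphTopologyView code):

#include <queue>
#include <vector>

// Returns the indices of nodes that could not be topologically ordered,
// i.e. nodes on or downstream of a cycle. fanout[u] lists the nodes fed by u.
std::vector<int> FindUnorderableNodes(const std::vector<std::vector<int>>& fanout) {
  const int n = static_cast<int>(fanout.size());
  std::vector<int> fanin_count(n, 0), num_ready_inputs(n, 0);
  for (int u = 0; u < n; ++u) {
    for (int v : fanout[u]) ++fanin_count[v];
  }
  std::queue<int> ready;
  for (int u = 0; u < n; ++u) {
    if (fanin_count[u] == 0) ready.push(u);
  }
  int processed = 0;
  while (!ready.empty()) {
    const int u = ready.front();
    ready.pop();
    ++processed;
    for (int v : fanout[u]) {
      if (++num_ready_inputs[v] == fanin_count[v]) ready.push(v);
    }
  }
  std::vector<int> not_ready;
  if (processed != n) {  // analogue of `back != graph_view.num_nodes()`
    for (int u = 0; u < n; ++u) {
      if (num_ready_inputs[u] != fanin_count[u]) not_ready.push_back(u);
    }
  }
  return not_ready;
}

For a two-node cycle {0 -> 1, 1 -> 0} this returns {0, 1}, roughly the set the new VLOG(1) loop reports as nodes that are not ready.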

View File

@@ -352,14 +352,17 @@ TEST_F(UtilsTest, NumNonControlOutputs) {
NodeMap node_map(&graph);
const NodeDef* add_node = node_map.GetNode("add");
const NodeDef* mul_node = node_map.GetNode("mul");
ASSERT_NE(add_node, nullptr);
// [a, b] are only non-control inputs
EXPECT_EQ(NumNonControlInputs(*add_node), 2);
EXPECT_EQ(NumControlInputs(*add_node), 1);
// [sqrt, shape] are non control outputs
EXPECT_EQ(NumNonControlOutputs(*add_node, node_map), 2);
// sqrt is the only data output
EXPECT_EQ(NumNonControlDataOutputs(*add_node, node_map), 1);
EXPECT_EQ(NumControlInputs(*mul_node), 0);
EXPECT_TRUE(HasControlInputs(*add_node));
EXPECT_TRUE(HasRegularInputs(*add_node));