[Grappler]
1) Skip dead branch elimination for merge nodes with control inputs, since these can create cycles in the resulting optimized graph. 2) Optimize a few utility functions. 3) Add more verbose VLOGging when topological sorting fails. PiperOrigin-RevId: 284871268 Change-Id: I36435402d826e4737b709468d88641d7a7fa2a83
This commit is contained in:
parent
77b30d97cb
commit
74229d4736
@ -627,7 +627,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
VLOG(3) << "Try to find a zero iteration while loop:"
|
||||
VLOG(4) << "Try to find a zero iteration while loop:"
|
||||
<< " switch_node=" << switch_node.name();
|
||||
|
||||
// Find the boolean predicate from a LoopCond node (e.g. Greater).
|
||||
@ -704,7 +704,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
|
||||
&constant_switch_value));
|
||||
|
||||
if (constant_switch_value == false) {
|
||||
VLOG(4) << "Remove 0 iteration while loop:"
|
||||
VLOG(3) << "Remove 0 iteration while loop:"
|
||||
<< " switch_node=" << switch_node.name();
|
||||
*has_dead_fanout = true;
|
||||
*dead_fanout = 1;
|
||||
@ -746,8 +746,6 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph));
|
||||
}
|
||||
if (options_.enable_dead_branch_removal) {
|
||||
// TODO(srjoglekar): Figure out if we can optimize NodeMap creations across
|
||||
// optimizer passes.
|
||||
NodeMap node_map(optimized_graph);
|
||||
absl::flat_hash_set<string> feed_nodes;
|
||||
for (const auto& feed : item.feed) {
|
||||
@ -890,43 +888,55 @@ Status LoopOptimizer::RemoveDeadBranches(
|
||||
// Names of the nodes that were removed from the graph.
|
||||
absl::flat_hash_set<absl::string_view> dead_node_names;
|
||||
dead_node_names.reserve(dead_nodes.size());
|
||||
for (const NodeDef* dead_node : dead_nodes)
|
||||
for (const NodeDef* dead_node : dead_nodes) {
|
||||
dead_node_names.insert(dead_node->name());
|
||||
}
|
||||
|
||||
// Remove dead inputs from Merge nodes that were not pruned from the graph.
|
||||
// Check that the merge nodes are valid.
|
||||
for (const auto& itr : dead_merge_inputs) {
|
||||
NodeDef* dead_node = itr.first;
|
||||
if (dead_nodes.find(dead_node) != dead_nodes.end()) {
|
||||
// The node has been pruned since all its inputs are dead.
|
||||
NodeDef* merge_node = itr.first;
|
||||
if (dead_nodes.find(merge_node) != dead_nodes.end()) {
|
||||
// The node will be pruned since all its inputs are dead.
|
||||
continue;
|
||||
}
|
||||
// Remove dead data input.
|
||||
const std::set<int>& dead_inputs = itr.second;
|
||||
CHECK_LE(dead_inputs.size(), 1);
|
||||
// (This loop would delete >1 items possibly in the wrong order.)
|
||||
for (int index : dead_inputs) {
|
||||
dead_node->mutable_input()->DeleteSubrange(index, 1);
|
||||
const int num_data_inputs = merge_node->attr().at("N").i();
|
||||
if (merge_node->input_size() != num_data_inputs) {
|
||||
LOG(WARNING)
|
||||
<< "Skipping loop optimization for Merge node with control input: "
|
||||
<< merge_node->name();
|
||||
return Status::OK();
|
||||
} else if (dead_inputs.size() != 1 || num_data_inputs != 2) {
|
||||
LOG(WARNING) << "Skipping loop optimization for Merge node ("
|
||||
<< merge_node->name()
|
||||
<< ") with unexpected dead_inputs.size() ("
|
||||
<< dead_inputs.size() << " or num_data_inputs"
|
||||
<< num_data_inputs;
|
||||
return Status::OK();
|
||||
}
|
||||
// Turn Merge into Identity only if we deleted the other data input.
|
||||
if (!dead_inputs.empty()) {
|
||||
const int num_data_inputs = dead_node->attr().at("N").i();
|
||||
CHECK_EQ(num_data_inputs, dead_inputs.size() + 1);
|
||||
dead_node->set_op("Identity");
|
||||
dead_node->mutable_attr()->erase("N");
|
||||
}
|
||||
// Remove control inputs from dead nodes.
|
||||
int pos = 0;
|
||||
while (pos < dead_node->input_size()) {
|
||||
TensorId tensor = ParseTensorName(dead_node->input(pos));
|
||||
if (tensor.index() == Graph::kControlSlot &&
|
||||
dead_node_names.contains(tensor.node())) {
|
||||
auto* inputs = dead_node->mutable_input();
|
||||
inputs->SwapElements(pos, dead_node->input_size() - 1);
|
||||
inputs->RemoveLast();
|
||||
} else {
|
||||
++pos;
|
||||
}
|
||||
}
|
||||
|
||||
// Remove dead inputs from Merge nodes that will not be not
|
||||
// pruned from the graph.
|
||||
for (const auto& itr : dead_merge_inputs) {
|
||||
NodeDef* merge_node = itr.first;
|
||||
if (dead_nodes.find(merge_node) != dead_nodes.end()) {
|
||||
// The node will be pruned since all its inputs are dead.
|
||||
continue;
|
||||
}
|
||||
VLOG(3) << "Merge node before cleanup: " << merge_node->DebugString();
|
||||
// Remove dead data input.
|
||||
const std::set<int>& dead_inputs = itr.second;
|
||||
int index = *dead_inputs.begin();
|
||||
auto* inputs = merge_node->mutable_input();
|
||||
inputs->SwapElements(1, index);
|
||||
inputs->SwapElements(1, merge_node->input_size() - 1);
|
||||
inputs->RemoveLast();
|
||||
merge_node->set_op("Identity");
|
||||
merge_node->mutable_attr()->erase("N");
|
||||
|
||||
VLOG(3) << "Merge node after cleanup: " << merge_node->DebugString();
|
||||
}
|
||||
|
||||
EraseNodesFromGraph(std::move(nodes_idx_to_delete), optimized_graph);
|
||||
|
@ -777,10 +777,6 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
|
||||
ops::Merge m3(scope.WithOpName("m3"), {v_in, sqrt1});
|
||||
ops::Merge m4(scope.WithOpName("m4"), {square1, sqrt2});
|
||||
ops::Merge m5(scope.WithOpName("m5"), {square2, sqrt1});
|
||||
ops::Merge m6(scope.WithOpName("m6").WithControlDependencies(sqrt2),
|
||||
{v_in, square1});
|
||||
ops::Merge m7(scope.WithOpName("m7").WithControlDependencies(sqrt1),
|
||||
{v_in, square1});
|
||||
|
||||
ops::Switch s5(scope.WithOpName("switch5"), v_in, ctrl1);
|
||||
Output id1 = ops::Identity(scope.WithOpName("id1"), s5.output_false);
|
||||
@ -831,19 +827,6 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
|
||||
ASSERT_EQ(node.input_size(), 2);
|
||||
EXPECT_EQ(node.input(0), "square1");
|
||||
EXPECT_EQ(node.input(1), "sqrt2");
|
||||
} else if (node.name() == "m6") {
|
||||
// both inputs are alive and the control dependency can get triggered
|
||||
EXPECT_EQ(node.op(), "Merge");
|
||||
ASSERT_EQ(node.input_size(), 3);
|
||||
EXPECT_EQ(node.input(0), "v_in");
|
||||
EXPECT_EQ(node.input(1), "square1");
|
||||
EXPECT_EQ(node.input(2), "^sqrt2");
|
||||
} else if (node.name() == "m7") {
|
||||
// removed control input from dead sqrt1
|
||||
EXPECT_EQ(node.op(), "Merge");
|
||||
ASSERT_EQ(node.input_size(), 2);
|
||||
EXPECT_EQ(node.input(0), "v_in");
|
||||
EXPECT_EQ(node.input(1), "square1");
|
||||
} else if (node.name() == "m8") {
|
||||
// The node is to be preserved because of a fetch
|
||||
EXPECT_EQ(node.op(), "Merge");
|
||||
@ -859,11 +842,11 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
|
||||
}
|
||||
}
|
||||
|
||||
auto tensors_expected = EvaluateNodes(item.graph, {"m7", "m8", "m9"});
|
||||
ASSERT_EQ(tensors_expected.size(), 3);
|
||||
auto tensors_expected = EvaluateNodes(item.graph, {"m8", "m9"});
|
||||
ASSERT_EQ(tensors_expected.size(), 2);
|
||||
|
||||
auto tensors = EvaluateNodes(output, {"m7", "m8", "m9"});
|
||||
ASSERT_EQ(tensors.size(), 3);
|
||||
auto tensors = EvaluateNodes(output, {"m8", "m9"});
|
||||
ASSERT_EQ(tensors.size(), 2);
|
||||
|
||||
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
|
||||
test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-6);
|
||||
@ -1098,7 +1081,6 @@ node {
|
||||
op: "Merge"
|
||||
input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency_1"
|
||||
input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency"
|
||||
input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert"
|
||||
device: "/job:localhost/replica:0/task:0/device:CPU:0"
|
||||
attr {
|
||||
key: "N"
|
||||
|
@ -221,7 +221,7 @@ Status MetaOptimizer::InitializeOptimizers(
|
||||
if (cfg_.function_optimization() != RewriterConfig::OFF) {
|
||||
optimizers->push_back(MakeUnique<FunctionOptimizer>(
|
||||
cfg_.function_optimization(),
|
||||
/*lower_contorl_flow=*/!IsSingleThreadedExecutor()));
|
||||
/*lower_control_flow=*/!IsSingleThreadedExecutor()));
|
||||
}
|
||||
if (cfg_.debug_stripper() == RewriterConfig::ON) {
|
||||
optimizers->push_back(MakeUnique<DebugStripper>());
|
||||
|
@ -277,10 +277,22 @@ bool HasRegularInputs(const NodeDef& node) {
|
||||
}
|
||||
|
||||
int NumNonControlInputs(const NodeDef& node) {
|
||||
int num_inputs = node.input_size();
|
||||
for (const string& input : node.input()) {
|
||||
int num_inputs = 0;
|
||||
for (; num_inputs < node.input_size(); ++num_inputs) {
|
||||
const string& input = node.input(num_inputs);
|
||||
if (IsControlInput(input)) {
|
||||
--num_inputs;
|
||||
return num_inputs;
|
||||
}
|
||||
}
|
||||
return num_inputs;
|
||||
}
|
||||
|
||||
int NumControlInputs(const NodeDef& node) {
|
||||
int num_inputs = 0;
|
||||
for (; num_inputs < node.input_size(); ++num_inputs) {
|
||||
const string& input = node.input(node.input_size() - num_inputs - 1);
|
||||
if (!IsControlInput(input)) {
|
||||
return num_inputs;
|
||||
}
|
||||
}
|
||||
return num_inputs;
|
||||
@ -302,8 +314,9 @@ bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) {
|
||||
|
||||
bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) {
|
||||
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
|
||||
for (const string& node_as_input : output->input()) {
|
||||
if (!IsControlInput(node_as_input)) continue;
|
||||
for (int idx = output->input_size() - 1; idx >= 0; --idx) {
|
||||
const string& node_as_input = output->input(idx);
|
||||
if (!IsControlInput(node_as_input)) break;
|
||||
|
||||
TensorId tensor = ParseTensorName(node_as_input);
|
||||
if (tensor.node() == node.name()) {
|
||||
@ -317,8 +330,9 @@ bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) {
|
||||
int NumControlOutputs(const NodeDef& node, const NodeMap& node_map) {
|
||||
int num_outputs = 0;
|
||||
for (const NodeDef* output : node_map.GetOutputs(node.name())) {
|
||||
for (const string& node_as_input : output->input()) {
|
||||
if (!IsControlInput(node_as_input)) continue;
|
||||
for (int idx = output->input_size() - 1; idx >= 0; --idx) {
|
||||
const string& node_as_input = output->input(idx);
|
||||
if (!IsControlInput(node_as_input)) break;
|
||||
|
||||
TensorId tensor = ParseTensorName(node_as_input);
|
||||
if (tensor.node() == node.name()) {
|
||||
|
@ -221,6 +221,9 @@ bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map);
|
||||
// Returns true iff the node has at least one control output.
|
||||
bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map);
|
||||
|
||||
// Number of connected control inputs.
|
||||
int NumControlInputs(const NodeDef& node);
|
||||
|
||||
// Number of connected non-control inputs.
|
||||
int NumNonControlInputs(const NodeDef& node);
|
||||
|
||||
|
@ -90,6 +90,16 @@ Status ComputeTopologicalOrder(
|
||||
}
|
||||
|
||||
if (back != graph_view.num_nodes()) {
|
||||
if (VLOG_IS_ON(1)) {
|
||||
VLOG(1) << "The graph couldn't be sorted in topological order. Stalled "
|
||||
"at node = "
|
||||
<< graph.node(back).DebugString();
|
||||
for (int i = 0; i < graph_view.num_nodes(); ++i) {
|
||||
if (num_ready_inputs[i] != graph_view.GetFanin(i).size()) {
|
||||
VLOG(1) << "Node not ready: " << graph.node(i).DebugString();
|
||||
}
|
||||
}
|
||||
}
|
||||
return errors::InvalidArgument(
|
||||
"The graph couldn't be sorted in topological order.");
|
||||
}
|
||||
|
@ -352,14 +352,17 @@ TEST_F(UtilsTest, NumNonControlOutputs) {
|
||||
NodeMap node_map(&graph);
|
||||
|
||||
const NodeDef* add_node = node_map.GetNode("add");
|
||||
const NodeDef* mul_node = node_map.GetNode("mul");
|
||||
ASSERT_NE(add_node, nullptr);
|
||||
|
||||
// [a, b] are only non-control inputs
|
||||
EXPECT_EQ(NumNonControlInputs(*add_node), 2);
|
||||
EXPECT_EQ(NumControlInputs(*add_node), 1);
|
||||
// [sqrt, shape] are non control outputs
|
||||
EXPECT_EQ(NumNonControlOutputs(*add_node, node_map), 2);
|
||||
// sqrt is the only data output
|
||||
EXPECT_EQ(NumNonControlDataOutputs(*add_node, node_map), 1);
|
||||
EXPECT_EQ(NumControlInputs(*mul_node), 0);
|
||||
|
||||
EXPECT_TRUE(HasControlInputs(*add_node));
|
||||
EXPECT_TRUE(HasRegularInputs(*add_node));
|
||||
|
Loading…
Reference in New Issue
Block a user