[Grappler]
1) Skip dead branch elimination for Merge nodes with control inputs, since these can create cycles in the resulting optimized graph.
2) Optimize a few utility functions.
3) Add more verbose VLOGging when topological sorting fails.

PiperOrigin-RevId: 284871268
Change-Id: I36435402d826e4737b709468d88641d7a7fa2a83
commit 74229d4736 (parent 77b30d97cb)
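The first change is what motivates most of the diff below: a Merge node declares its number of data inputs in its "N" attr, so any extra entries in its input list are control edges (inputs prefixed with "^"). Rewriting such a Merge during dead-branch removal can re-route that control edge in a way that closes a cycle, so the optimizer now detects the situation and bails out. A minimal illustrative sketch of that detection (not code from the commit; node names and values are invented):

// Hand-built Merge node with one control input, to show the check the
// optimizer now performs before touching a Merge node.
#include <iostream>

#include "tensorflow/core/framework/node_def.pb.h"

int main() {
  tensorflow::NodeDef merge;
  merge.set_name("m6");
  merge.set_op("Merge");
  merge.add_input("v_in");     // data input 0
  merge.add_input("square1");  // data input 1
  merge.add_input("^sqrt2");   // control input: "^" prefix, listed last
  (*merge.mutable_attr())["N"].set_i(2);  // Merge declares two data inputs

  const int num_data_inputs = merge.attr().at("N").i();
  if (merge.input_size() != num_data_inputs) {
    // This is the condition the new validation loop treats as "skip this node".
    std::cout << "Merge " << merge.name() << " carries a control input\n";
  }
  return 0;
}

The same comparison of input_size() against the "N" attr is exactly what the new validation loop in RemoveDeadBranches performs, as shown in the fourth hunk below.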
@@ -627,7 +627,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
     return Status::OK();
   }
 
-  VLOG(3) << "Try to find a zero iteration while loop:"
+  VLOG(4) << "Try to find a zero iteration while loop:"
           << " switch_node=" << switch_node.name();
 
   // Find the boolean predicate from a LoopCond node (e.g. Greater).
@@ -704,7 +704,7 @@ Status CheckForDeadFanout(const MutableGraphView& view,
                                         &constant_switch_value));
 
   if (constant_switch_value == false) {
-    VLOG(4) << "Remove 0 iteration while loop:"
+    VLOG(3) << "Remove 0 iteration while loop:"
            << " switch_node=" << switch_node.name();
    *has_dead_fanout = true;
    *dead_fanout = 1;
@@ -746,8 +746,6 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
     TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph));
   }
   if (options_.enable_dead_branch_removal) {
-    // TODO(srjoglekar): Figure out if we can optimize NodeMap creations across
-    // optimizer passes.
     NodeMap node_map(optimized_graph);
     absl::flat_hash_set<string> feed_nodes;
     for (const auto& feed : item.feed) {
@@ -890,43 +888,55 @@ Status LoopOptimizer::RemoveDeadBranches(
   // Names of the nodes that were removed from the graph.
   absl::flat_hash_set<absl::string_view> dead_node_names;
   dead_node_names.reserve(dead_nodes.size());
-  for (const NodeDef* dead_node : dead_nodes)
+  for (const NodeDef* dead_node : dead_nodes) {
     dead_node_names.insert(dead_node->name());
+  }
 
-  // Remove dead inputs from Merge nodes that were not pruned from the graph.
+  // Check that the merge nodes are valid.
   for (const auto& itr : dead_merge_inputs) {
-    NodeDef* dead_node = itr.first;
-    if (dead_nodes.find(dead_node) != dead_nodes.end()) {
-      // The node has been pruned since all its inputs are dead.
+    NodeDef* merge_node = itr.first;
+    if (dead_nodes.find(merge_node) != dead_nodes.end()) {
+      // The node will be pruned since all its inputs are dead.
       continue;
     }
     // Remove dead data input.
     const std::set<int>& dead_inputs = itr.second;
-    CHECK_LE(dead_inputs.size(), 1);
-    // (This loop would delete >1 items possibly in the wrong order.)
-    for (int index : dead_inputs) {
-      dead_node->mutable_input()->DeleteSubrange(index, 1);
-    }
-    // Turn Merge into Identity only if we deleted the other data input.
-    if (!dead_inputs.empty()) {
-      const int num_data_inputs = dead_node->attr().at("N").i();
-      CHECK_EQ(num_data_inputs, dead_inputs.size() + 1);
-      dead_node->set_op("Identity");
-      dead_node->mutable_attr()->erase("N");
-    }
-    // Remove control inputs from dead nodes.
-    int pos = 0;
-    while (pos < dead_node->input_size()) {
-      TensorId tensor = ParseTensorName(dead_node->input(pos));
-      if (tensor.index() == Graph::kControlSlot &&
-          dead_node_names.contains(tensor.node())) {
-        auto* inputs = dead_node->mutable_input();
-        inputs->SwapElements(pos, dead_node->input_size() - 1);
-        inputs->RemoveLast();
-      } else {
-        ++pos;
-      }
-    }
+    const int num_data_inputs = merge_node->attr().at("N").i();
+    if (merge_node->input_size() != num_data_inputs) {
+      LOG(WARNING)
+          << "Skipping loop optimization for Merge node with control input: "
+          << merge_node->name();
+      return Status::OK();
+    } else if (dead_inputs.size() != 1 || num_data_inputs != 2) {
+      LOG(WARNING) << "Skipping loop optimization for Merge node ("
+                   << merge_node->name()
+                   << ") with unexpected dead_inputs.size() ("
+                   << dead_inputs.size() << " or num_data_inputs"
+                   << num_data_inputs;
+      return Status::OK();
+    }
+  }
+
+  // Remove dead inputs from Merge nodes that will not be not
+  // pruned from the graph.
+  for (const auto& itr : dead_merge_inputs) {
+    NodeDef* merge_node = itr.first;
+    if (dead_nodes.find(merge_node) != dead_nodes.end()) {
+      // The node will be pruned since all its inputs are dead.
+      continue;
+    }
+    VLOG(3) << "Merge node before cleanup: " << merge_node->DebugString();
+    // Remove dead data input.
+    const std::set<int>& dead_inputs = itr.second;
+    int index = *dead_inputs.begin();
+    auto* inputs = merge_node->mutable_input();
+    inputs->SwapElements(1, index);
+    inputs->SwapElements(1, merge_node->input_size() - 1);
+    inputs->RemoveLast();
+    merge_node->set_op("Identity");
+    merge_node->mutable_attr()->erase("N");
+
+    VLOG(3) << "Merge node after cleanup: " << merge_node->DebugString();
   }
 
   EraseNodesFromGraph(std::move(nodes_idx_to_delete), optimized_graph);
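To make the new cleanup loop above concrete, here is an illustrative walk-through on a hand-built two-input Merge (invented names; not code from the commit). The double SwapElements keeps the surviving data input in slot 0 regardless of which input was dead, and a one-input Merge is then equivalent to an Identity node:

// Hand-built Merge with data input 0 dead and data input 1 alive.
#include <iostream>

#include "tensorflow/core/framework/node_def.pb.h"

int main() {
  tensorflow::NodeDef merge;
  merge.set_op("Merge");
  merge.add_input("dead_branch");   // index 0: dead data input
  merge.add_input("alive_branch");  // index 1: surviving data input
  (*merge.mutable_attr())["N"].set_i(2);

  const int dead_index = 0;  // stands in for *dead_inputs.begin()
  auto* inputs = merge.mutable_input();
  inputs->SwapElements(1, dead_index);              // dead input -> slot 1, survivor -> slot 0
  inputs->SwapElements(1, merge.input_size() - 1);  // dead input -> last slot (no-op here)
  inputs->RemoveLast();                             // drop the dead input
  merge.set_op("Identity");                         // a one-input Merge is just Identity
  merge.mutable_attr()->erase("N");                 // "N" is a Merge-only attr

  std::cout << merge.DebugString();  // op: "Identity", input: "alive_branch"
  return 0;
}

Note that this cleanup only runs for Merge nodes that passed the validation loop, i.e. nodes with exactly two data inputs and no control inputs.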
@@ -777,10 +777,6 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
   ops::Merge m3(scope.WithOpName("m3"), {v_in, sqrt1});
   ops::Merge m4(scope.WithOpName("m4"), {square1, sqrt2});
   ops::Merge m5(scope.WithOpName("m5"), {square2, sqrt1});
-  ops::Merge m6(scope.WithOpName("m6").WithControlDependencies(sqrt2),
-                {v_in, square1});
-  ops::Merge m7(scope.WithOpName("m7").WithControlDependencies(sqrt1),
-                {v_in, square1});
 
   ops::Switch s5(scope.WithOpName("switch5"), v_in, ctrl1);
   Output id1 = ops::Identity(scope.WithOpName("id1"), s5.output_false);
@@ -831,19 +827,6 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
       ASSERT_EQ(node.input_size(), 2);
       EXPECT_EQ(node.input(0), "square1");
       EXPECT_EQ(node.input(1), "sqrt2");
-    } else if (node.name() == "m6") {
-      // both inputs are alive and the control dependency can get triggered
-      EXPECT_EQ(node.op(), "Merge");
-      ASSERT_EQ(node.input_size(), 3);
-      EXPECT_EQ(node.input(0), "v_in");
-      EXPECT_EQ(node.input(1), "square1");
-      EXPECT_EQ(node.input(2), "^sqrt2");
-    } else if (node.name() == "m7") {
-      // removed control input from dead sqrt1
-      EXPECT_EQ(node.op(), "Merge");
-      ASSERT_EQ(node.input_size(), 2);
-      EXPECT_EQ(node.input(0), "v_in");
-      EXPECT_EQ(node.input(1), "square1");
     } else if (node.name() == "m8") {
       // The node is to be preserved because of a fetch
       EXPECT_EQ(node.op(), "Merge");
@@ -859,11 +842,11 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) {
     }
   }
 
-  auto tensors_expected = EvaluateNodes(item.graph, {"m7", "m8", "m9"});
-  ASSERT_EQ(tensors_expected.size(), 3);
+  auto tensors_expected = EvaluateNodes(item.graph, {"m8", "m9"});
+  ASSERT_EQ(tensors_expected.size(), 2);
 
-  auto tensors = EvaluateNodes(output, {"m7", "m8", "m9"});
-  ASSERT_EQ(tensors.size(), 3);
+  auto tensors = EvaluateNodes(output, {"m8", "m9"});
+  ASSERT_EQ(tensors.size(), 2);
 
   test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
   test::ExpectTensorNear<float>(tensors_expected[1], tensors[1], 1e-6);
@@ -1098,7 +1081,6 @@ node {
   op: "Merge"
   input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency_1"
   input: "EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/control_dependency"
-  input: "^EpisodicReplayBuffer/add/assert_equal/Assert/AssertGuard/Assert"
   device: "/job:localhost/replica:0/task:0/device:CPU:0"
   attr {
     key: "N"
@@ -221,7 +221,7 @@ Status MetaOptimizer::InitializeOptimizers(
   if (cfg_.function_optimization() != RewriterConfig::OFF) {
     optimizers->push_back(MakeUnique<FunctionOptimizer>(
         cfg_.function_optimization(),
-        /*lower_contorl_flow=*/!IsSingleThreadedExecutor()));
+        /*lower_control_flow=*/!IsSingleThreadedExecutor()));
   }
   if (cfg_.debug_stripper() == RewriterConfig::ON) {
     optimizers->push_back(MakeUnique<DebugStripper>());
@@ -277,10 +277,22 @@ bool HasRegularInputs(const NodeDef& node) {
 }
 
 int NumNonControlInputs(const NodeDef& node) {
-  int num_inputs = node.input_size();
-  for (const string& input : node.input()) {
+  int num_inputs = 0;
+  for (; num_inputs < node.input_size(); ++num_inputs) {
+    const string& input = node.input(num_inputs);
     if (IsControlInput(input)) {
-      --num_inputs;
+      return num_inputs;
+    }
+  }
+  return num_inputs;
+}
+
+int NumControlInputs(const NodeDef& node) {
+  int num_inputs = 0;
+  for (; num_inputs < node.input_size(); ++num_inputs) {
+    const string& input = node.input(node.input_size() - num_inputs - 1);
+    if (!IsControlInput(input)) {
+      return num_inputs;
     }
   }
   return num_inputs;
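The rewritten helpers above appear to rely on the GraphDef convention that control inputs ("^name") are appended after all regular inputs, which is why NumNonControlInputs can stop scanning at the first control input and the new NumControlInputs can count from the back. A small illustration under that assumption (hand-built node, invented names):

// Sketch of the invariant the rewritten helpers rely on: in a valid GraphDef,
// control inputs trail the data inputs.
#include <iostream>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/utils.h"

int main() {
  tensorflow::NodeDef add;
  add.set_name("add");
  add.set_op("Add");
  add.add_input("a");
  add.add_input("b");
  add.add_input("^ctrl");  // control inputs come after the data inputs

  // Scanning forward stops at the first control input; scanning backward
  // stops at the first data input.
  std::cout << tensorflow::grappler::NumNonControlInputs(add) << "\n";  // 2
  std::cout << tensorflow::grappler::NumControlInputs(add) << "\n";     // 1
  return 0;
}

The same early-exit idea drives the reversed loops in HasControlOutputs and NumControlOutputs in the next two hunks: they walk each fanout's input list from the end and break as soon as a non-control input is seen.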
@@ -302,8 +314,9 @@ bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map) {
 
 bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) {
   for (const NodeDef* output : node_map.GetOutputs(node.name())) {
-    for (const string& node_as_input : output->input()) {
-      if (!IsControlInput(node_as_input)) continue;
+    for (int idx = output->input_size() - 1; idx >= 0; --idx) {
+      const string& node_as_input = output->input(idx);
+      if (!IsControlInput(node_as_input)) break;
 
       TensorId tensor = ParseTensorName(node_as_input);
       if (tensor.node() == node.name()) {
@@ -317,8 +330,9 @@ bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map) {
 int NumControlOutputs(const NodeDef& node, const NodeMap& node_map) {
   int num_outputs = 0;
   for (const NodeDef* output : node_map.GetOutputs(node.name())) {
-    for (const string& node_as_input : output->input()) {
-      if (!IsControlInput(node_as_input)) continue;
+    for (int idx = output->input_size() - 1; idx >= 0; --idx) {
+      const string& node_as_input = output->input(idx);
+      if (!IsControlInput(node_as_input)) break;
 
       TensorId tensor = ParseTensorName(node_as_input);
       if (tensor.node() == node.name()) {
@@ -221,6 +221,9 @@ bool HasRegularOutputs(const NodeDef& node, const NodeMap& node_map);
 // Returns true iff the node has at least one control output.
 bool HasControlOutputs(const NodeDef& node, const NodeMap& node_map);
 
+// Number of connected control inputs.
+int NumControlInputs(const NodeDef& node);
+
 // Number of connected non-control inputs.
 int NumNonControlInputs(const NodeDef& node);
 
@@ -90,6 +90,16 @@ Status ComputeTopologicalOrder(
   }
 
   if (back != graph_view.num_nodes()) {
+    if (VLOG_IS_ON(1)) {
+      VLOG(1) << "The graph couldn't be sorted in topological order. Stalled "
+                 "at node = "
+              << graph.node(back).DebugString();
+      for (int i = 0; i < graph_view.num_nodes(); ++i) {
+        if (num_ready_inputs[i] != graph_view.GetFanin(i).size()) {
+          VLOG(1) << "Node not ready: " << graph.node(i).DebugString();
+        }
+      }
+    }
     return errors::InvalidArgument(
         "The graph couldn't be sorted in topological order.");
   }
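The new diagnostics above fire only when the topological sort stalls, i.e. some nodes never have all of their fan-ins marked ready, which typically indicates a cycle. A hypothetical two-node cycle that would trigger them (illustrative only, not taken from the commit):

// Two Identity nodes that depend on each other; neither can ever become
// "ready", so ComputeTopologicalOrder would stall and, with VLOG(1) enabled,
// list both nodes as not ready.
#include "tensorflow/core/framework/graph.pb.h"

tensorflow::GraphDef MakeCyclicGraph() {
  tensorflow::GraphDef graph;
  tensorflow::NodeDef* a = graph.add_node();
  a->set_name("a");
  a->set_op("Identity");
  a->add_input("b");  // a depends on b ...
  tensorflow::NodeDef* b = graph.add_node();
  b->set_name("b");
  b->set_op("Identity");
  b->add_input("a");  // ... and b depends on a.
  return graph;
}

Since the extra output is guarded by VLOG_IS_ON(1), it typically only appears when verbose logging is enabled, for example via the TF_CPP_MIN_VLOG_LEVEL environment variable.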
@@ -352,14 +352,17 @@ TEST_F(UtilsTest, NumNonControlOutputs) {
   NodeMap node_map(&graph);
 
   const NodeDef* add_node = node_map.GetNode("add");
+  const NodeDef* mul_node = node_map.GetNode("mul");
   ASSERT_NE(add_node, nullptr);
 
   // [a, b] are only non-control inputs
   EXPECT_EQ(NumNonControlInputs(*add_node), 2);
+  EXPECT_EQ(NumControlInputs(*add_node), 1);
   // [sqrt, shape] are non control outputs
   EXPECT_EQ(NumNonControlOutputs(*add_node, node_map), 2);
   // sqrt is the only data output
   EXPECT_EQ(NumNonControlDataOutputs(*add_node, node_map), 1);
+  EXPECT_EQ(NumControlInputs(*mul_node), 0);
 
   EXPECT_TRUE(HasControlInputs(*add_node));
   EXPECT_TRUE(HasRegularInputs(*add_node));
|
Loading…
Reference in New Issue
Block a user