Extend Identity optimizations to IdentityN.

PiperOrigin-RevId: 219327001
This commit is contained in:
A. Unique TensorFlower 2018-10-30 10:48:50 -07:00 committed by TensorFlower Gardener
parent 55484438d7
commit 507c566376
3 changed files with 101 additions and 31 deletions

View File

@ -57,7 +57,7 @@ bool RemoveInput(NodeDef* node, const string& input, NodeMap* node_map) {
} // namespace
bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
if (!IsIdentity(node) && !IsIdentityNSingleInput(node)) {
if (!IsIdentity(node) && !IsIdentityN(node)) {
return true;
}
@ -133,15 +133,56 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
return true;
}
int DependencyOptimizer::NumEdgesIfBypassed(
const NodeDef& node, const std::vector<NodeDef*>& output_nodes) const {
const bool is_multi_input_identity_n =
IsIdentityN(node) && !IsIdentityNSingleInput(node);
const int num_outputs = output_nodes.size();
const int num_inputs = node.input_size();
if (is_multi_input_identity_n) {
// multi-input identity_n with input/output control dependencies will likely
// increase number of edges after optimization.
int num_edges_if_bypassed(0);
for (string input_node_name : node.input()) {
if (IsControlInput(input_node_name)) {
num_edges_if_bypassed += num_outputs;
} else {
++num_edges_if_bypassed;
}
}
for (auto consumer : output_nodes) {
for (int j = 0; j < consumer->input_size(); ++j) {
const string& consumer_input = consumer->input(j);
int consumer_input_pos;
StringPiece consumer_input_node_name =
ParseNodeNameAsStringPiece(consumer_input, &consumer_input_pos);
if (consumer_input_node_name == node.name()) {
if (IsControlInput(consumer_input)) {
num_edges_if_bypassed += num_inputs;
} else {
++num_edges_if_bypassed;
}
}
}
}
return num_edges_if_bypassed;
} else {
return num_inputs * num_outputs;
}
}
bool DependencyOptimizer::BypassingNodeIsBeneficial(
const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
const std::vector<NodeDef*>& output_nodes) const {
const bool is_identity = IsIdentity(node) || IsIdentityNSingleInput(node);
const bool is_multi_input_identity_n =
IsIdentityN(node) && !IsIdentityNSingleInput(node);
const int num_outputs = output_nodes.size();
const int num_inputs = node.input_size();
// Don't increase the number of edges in the graph.
if (num_inputs * num_outputs > num_inputs + num_outputs) {
if (NumEdgesIfBypassed(node, output_nodes) > num_inputs + num_outputs) {
return false;
}
@ -166,7 +207,9 @@ bool DependencyOptimizer::BypassingNodeIsBeneficial(
for (NodeDef* output_node : output_nodes) {
num_cross_out += static_cast<int>(output_node->device() != node_dev);
}
if (is_identity && num_cross_in > 0 && num_cross_out > 0) {
if ((is_identity || is_multi_input_identity_n) && num_cross_in > 0 &&
num_cross_out > 0) {
// This identity node follows a device crossing, so it might be
// following a _Recv node after partioning. Do not remove such nodes,
// unless they only have consumers on the same device as themselves.
@ -194,6 +237,8 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
NodeDef* node = optimized_graph_->mutable_node(node_idx);
const bool is_noop = IsNoOp(*node);
const bool is_identity = IsIdentity(*node) || IsIdentityNSingleInput(*node);
const bool is_multi_input_identity =
IsIdentityN(*node) && !IsIdentityNSingleInput(*node);
const string node_name = node->name();
// Constant nodes with no input control dependency are always executed early,
// so we can prune all their output control dependencies.
@ -315,7 +360,8 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
// y --^> | | --^> b /\ +---+
// +----------+ y --^> b
if (is_noop || (is_identity && SafeToRemoveIdentity(*node))) {
if (is_noop || ((is_identity || is_multi_input_identity) &&
SafeToRemoveIdentity(*node))) {
const auto& output_node_set = node_map_->GetOutputs(node_name);
const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
output_node_set.end());
@ -343,11 +389,11 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
const NodeDef* input = input_nodes[i];
// Forward dependency from input to consumer if it doesn't already
// depend on it.
if (is_identity && i == 0) {
if ((is_identity && i == 0) ||
(is_multi_input_identity && !IsControlInput(node->input(i)))) {
// Replace regular input from Identity node.
bool found_input = false;
string new_input;
const string& input_to_forward = node->input(0);
const string& input_to_forward = node->input(i);
CHECK(!IsControlInput(input_to_forward));
for (int j = 0; j < consumer->input_size(); ++j) {
const string& old_input = consumer->input(j);
@ -355,22 +401,19 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
StringPiece old_input_node_name =
ParseNodeNameAsStringPiece(old_input, &old_input_pos);
if (old_input_node_name == node_name) {
if (old_input_pos >= 0) {
if (old_input_pos == i) {
// Regular input
new_input = input_to_forward;
node_map_->UpdateInput(consumer->name(), old_input, new_input);
consumer->set_input(j, new_input);
found_input = true;
} else {
} else if (old_input_pos == -1) {
// Control dependency
new_input = AsControlDependency(NodeName(input_to_forward));
node_map_->UpdateInput(consumer->name(), old_input, new_input);
consumer->set_input(j, new_input);
found_input = true;
}
}
}
CHECK(found_input);
updated_consumer = true;
} else {
// Forward dependency from input to consumer if it doesn't already
@ -415,7 +458,7 @@ Status DependencyOptimizer::OptimizeDependencies() {
std::set<int> nodes_to_delete;
for (int i = 0; i < optimized_graph_->node_size(); ++i) {
const NodeDef& node = optimized_graph_->node(i);
if (IsNoOp(node) || IsIdentity(node) || IsIdentityNSingleInput(node) ||
if (IsNoOp(node) || IsIdentity(node) || IsIdentityN(node) ||
IsConstant(node) || SafeToConvertToNoOp(node)) {
nodes_to_simplify.PushBack(i);
}

View File

@ -48,7 +48,8 @@ class DependencyOptimizer : public GraphOptimizer {
bool BypassingNodeIsBeneficial(
const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
const std::vector<NodeDef*>& output_nodes) const;
int NumEdgesIfBypassed(const NodeDef& node,
const std::vector<NodeDef*>& output_nodes) const;
// Returns true if node is not an Identity node or if it is an Identity
// that is safe to remove.
bool SafeToRemoveIdentity(const NodeDef& node) const;

View File

@ -634,7 +634,7 @@ TEST_F(DependencyOptimizerTest, IdentityInputs) {
EXPECT_EQ("s:1", output.node(5).input(0));
}
TEST_F(DependencyOptimizerTest, IdentityN) {
TEST_F(DependencyOptimizerTest, RemoveIdentityN_SwitchInput) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL);
Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT);
@ -643,8 +643,6 @@ TEST_F(DependencyOptimizerTest, IdentityN) {
// IdentityN nodes to be removed.
auto id_f = ops::IdentityN(scope.WithOpName("id_f"), {s.output_false});
auto id_t = ops::IdentityN(scope.WithOpName("id_t"), {s.output_true});
// IdentityN node that can't be removed.
auto id_b =
ops::IdentityN(scope.WithOpName("id_b"), {s.output_false, s.output_true});
@ -663,22 +661,50 @@ TEST_F(DependencyOptimizerTest, IdentityN) {
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
EXPECT_EQ(9, output.node_size());
EXPECT_EQ("out1", output.node(5).name());
EXPECT_EQ(1, output.node(5).input_size());
EXPECT_EQ("s", output.node(5).input(0));
EXPECT_EQ(8, output.node_size());
EXPECT_EQ("out2", output.node(6).name());
EXPECT_EQ(1, output.node(6).input_size());
EXPECT_EQ("s:1", output.node(6).input(0));
auto out1_node = output.node(7);
EXPECT_EQ("out1", out1_node.name());
EXPECT_EQ(1, out1_node.input_size());
EXPECT_EQ("s", out1_node.input(0));
EXPECT_EQ("out3", output.node(7).name());
EXPECT_EQ(1, output.node(7).input_size());
EXPECT_EQ("id_b", output.node(7).input(0));
auto out2_node = output.node(4);
EXPECT_EQ("out2", out2_node.name());
EXPECT_EQ(1, out2_node.input_size());
EXPECT_EQ("s:1", out2_node.input(0));
EXPECT_EQ("out4", output.node(8).name());
EXPECT_EQ(1, output.node(8).input_size());
EXPECT_EQ("id_b:1", output.node(8).input(0));
auto out3_node = output.node(5);
EXPECT_EQ("out3", out3_node.name());
EXPECT_EQ(1, out3_node.input_size());
EXPECT_EQ("s", out3_node.input(0));
auto out4_node = output.node(6);
EXPECT_EQ("out4", out4_node.name());
EXPECT_EQ(1, out4_node.input_size());
EXPECT_EQ("s:1", out4_node.input(0));
}
TEST_F(DependencyOptimizerTest, DoNotRemoveIdentityNWithControlDependency) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
Output input1 = ops::Placeholder(scope.WithOpName("input1"), DT_BOOL);
Output input2 = ops::Const(scope.WithOpName("input2"), {1, 2});
auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {input1, input2});
Output out1 = ops::Identity(scope.WithOpName("out1"), id_n[0]);
Output out2 = ops::Identity(scope.WithOpName("out2"), id_n[1]);
auto out3 =
ops::NoOp(scope.WithOpName("out3").WithControlDependencies(id_n[1]));
GrapplerItem item;
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
item.fetch = {"out1", "out2", "out3"};
DependencyOptimizer optimizer;
GraphDef optimized_graph_def;
Status status = optimizer.Optimize(nullptr, item, &optimized_graph_def);
TF_EXPECT_OK(status);
EXPECT_EQ(6, optimized_graph_def.node_size());
}
TEST_F(DependencyOptimizerTest,