Extend Identity optimizations to IdentityN.
PiperOrigin-RevId: 219327001
This commit is contained in:
parent
55484438d7
commit
507c566376
@ -57,7 +57,7 @@ bool RemoveInput(NodeDef* node, const string& input, NodeMap* node_map) {
|
||||
} // namespace
|
||||
|
||||
bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
|
||||
if (!IsIdentity(node) && !IsIdentityNSingleInput(node)) {
|
||||
if (!IsIdentity(node) && !IsIdentityN(node)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -133,15 +133,56 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
int DependencyOptimizer::NumEdgesIfBypassed(
|
||||
const NodeDef& node, const std::vector<NodeDef*>& output_nodes) const {
|
||||
const bool is_multi_input_identity_n =
|
||||
IsIdentityN(node) && !IsIdentityNSingleInput(node);
|
||||
const int num_outputs = output_nodes.size();
|
||||
const int num_inputs = node.input_size();
|
||||
|
||||
if (is_multi_input_identity_n) {
|
||||
// multi-input identity_n with input/output control dependencies will likely
|
||||
// increase number of edges after optimization.
|
||||
int num_edges_if_bypassed(0);
|
||||
for (string input_node_name : node.input()) {
|
||||
if (IsControlInput(input_node_name)) {
|
||||
num_edges_if_bypassed += num_outputs;
|
||||
} else {
|
||||
++num_edges_if_bypassed;
|
||||
}
|
||||
}
|
||||
|
||||
for (auto consumer : output_nodes) {
|
||||
for (int j = 0; j < consumer->input_size(); ++j) {
|
||||
const string& consumer_input = consumer->input(j);
|
||||
int consumer_input_pos;
|
||||
StringPiece consumer_input_node_name =
|
||||
ParseNodeNameAsStringPiece(consumer_input, &consumer_input_pos);
|
||||
if (consumer_input_node_name == node.name()) {
|
||||
if (IsControlInput(consumer_input)) {
|
||||
num_edges_if_bypassed += num_inputs;
|
||||
} else {
|
||||
++num_edges_if_bypassed;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return num_edges_if_bypassed;
|
||||
} else {
|
||||
return num_inputs * num_outputs;
|
||||
}
|
||||
}
|
||||
|
||||
bool DependencyOptimizer::BypassingNodeIsBeneficial(
|
||||
const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
|
||||
const std::vector<NodeDef*>& output_nodes) const {
|
||||
const bool is_identity = IsIdentity(node) || IsIdentityNSingleInput(node);
|
||||
const bool is_multi_input_identity_n =
|
||||
IsIdentityN(node) && !IsIdentityNSingleInput(node);
|
||||
const int num_outputs = output_nodes.size();
|
||||
const int num_inputs = node.input_size();
|
||||
|
||||
// Don't increase the number of edges in the graph.
|
||||
if (num_inputs * num_outputs > num_inputs + num_outputs) {
|
||||
if (NumEdgesIfBypassed(node, output_nodes) > num_inputs + num_outputs) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -166,7 +207,9 @@ bool DependencyOptimizer::BypassingNodeIsBeneficial(
|
||||
for (NodeDef* output_node : output_nodes) {
|
||||
num_cross_out += static_cast<int>(output_node->device() != node_dev);
|
||||
}
|
||||
if (is_identity && num_cross_in > 0 && num_cross_out > 0) {
|
||||
|
||||
if ((is_identity || is_multi_input_identity_n) && num_cross_in > 0 &&
|
||||
num_cross_out > 0) {
|
||||
// This identity node follows a device crossing, so it might be
|
||||
// following a _Recv node after partitioning. Do not remove such nodes,
|
||||
// unless they only have consumers on the same device as themselves.
|
||||
@ -194,6 +237,8 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
|
||||
NodeDef* node = optimized_graph_->mutable_node(node_idx);
|
||||
const bool is_noop = IsNoOp(*node);
|
||||
const bool is_identity = IsIdentity(*node) || IsIdentityNSingleInput(*node);
|
||||
const bool is_multi_input_identity =
|
||||
IsIdentityN(*node) && !IsIdentityNSingleInput(*node);
|
||||
const string node_name = node->name();
|
||||
// Constant nodes with no input control dependency are always executed early,
|
||||
// so we can prune all their output control dependencies.
|
||||
@ -315,7 +360,8 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
|
||||
// y --^> | | --^> b /\ +---+
|
||||
// +----------+ y --^> b
|
||||
|
||||
if (is_noop || (is_identity && SafeToRemoveIdentity(*node))) {
|
||||
if (is_noop || ((is_identity || is_multi_input_identity) &&
|
||||
SafeToRemoveIdentity(*node))) {
|
||||
const auto& output_node_set = node_map_->GetOutputs(node_name);
|
||||
const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
|
||||
output_node_set.end());
|
||||
@ -343,11 +389,11 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
|
||||
const NodeDef* input = input_nodes[i];
|
||||
// Forward dependency from input to consumer if it doesn't already
|
||||
// depend on it.
|
||||
if (is_identity && i == 0) {
|
||||
if ((is_identity && i == 0) ||
|
||||
(is_multi_input_identity && !IsControlInput(node->input(i)))) {
|
||||
// Replace regular input from Identity node.
|
||||
bool found_input = false;
|
||||
string new_input;
|
||||
const string& input_to_forward = node->input(0);
|
||||
const string& input_to_forward = node->input(i);
|
||||
CHECK(!IsControlInput(input_to_forward));
|
||||
for (int j = 0; j < consumer->input_size(); ++j) {
|
||||
const string& old_input = consumer->input(j);
|
||||
@ -355,22 +401,19 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
|
||||
StringPiece old_input_node_name =
|
||||
ParseNodeNameAsStringPiece(old_input, &old_input_pos);
|
||||
if (old_input_node_name == node_name) {
|
||||
if (old_input_pos >= 0) {
|
||||
if (old_input_pos == i) {
|
||||
// Regular input
|
||||
new_input = input_to_forward;
|
||||
node_map_->UpdateInput(consumer->name(), old_input, new_input);
|
||||
consumer->set_input(j, new_input);
|
||||
found_input = true;
|
||||
} else {
|
||||
} else if (old_input_pos == -1) {
|
||||
// Control dependency
|
||||
new_input = AsControlDependency(NodeName(input_to_forward));
|
||||
node_map_->UpdateInput(consumer->name(), old_input, new_input);
|
||||
consumer->set_input(j, new_input);
|
||||
found_input = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
CHECK(found_input);
|
||||
updated_consumer = true;
|
||||
} else {
|
||||
// Forward dependency from input to consumer if it doesn't already
|
||||
@ -415,7 +458,7 @@ Status DependencyOptimizer::OptimizeDependencies() {
|
||||
std::set<int> nodes_to_delete;
|
||||
for (int i = 0; i < optimized_graph_->node_size(); ++i) {
|
||||
const NodeDef& node = optimized_graph_->node(i);
|
||||
if (IsNoOp(node) || IsIdentity(node) || IsIdentityNSingleInput(node) ||
|
||||
if (IsNoOp(node) || IsIdentity(node) || IsIdentityN(node) ||
|
||||
IsConstant(node) || SafeToConvertToNoOp(node)) {
|
||||
nodes_to_simplify.PushBack(i);
|
||||
}
|
||||
|
@ -48,7 +48,8 @@ class DependencyOptimizer : public GraphOptimizer {
|
||||
bool BypassingNodeIsBeneficial(
|
||||
const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
|
||||
const std::vector<NodeDef*>& output_nodes) const;
|
||||
|
||||
int NumEdgesIfBypassed(const NodeDef& node,
|
||||
const std::vector<NodeDef*>& output_nodes) const;
|
||||
// Returns true if node is not an Identity node or if it is an Identity
|
||||
// that is safe to remove.
|
||||
bool SafeToRemoveIdentity(const NodeDef& node) const;
|
||||
|
@ -634,7 +634,7 @@ TEST_F(DependencyOptimizerTest, IdentityInputs) {
|
||||
EXPECT_EQ("s:1", output.node(5).input(0));
|
||||
}
|
||||
|
||||
TEST_F(DependencyOptimizerTest, IdentityN) {
|
||||
TEST_F(DependencyOptimizerTest, RemoveIdentityN_SwitchInput) {
|
||||
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
|
||||
Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL);
|
||||
Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT);
|
||||
@ -643,8 +643,6 @@ TEST_F(DependencyOptimizerTest, IdentityN) {
|
||||
// IdentityN nodes to be removed.
|
||||
auto id_f = ops::IdentityN(scope.WithOpName("id_f"), {s.output_false});
|
||||
auto id_t = ops::IdentityN(scope.WithOpName("id_t"), {s.output_true});
|
||||
|
||||
// IdentityN node that can't be removed.
|
||||
auto id_b =
|
||||
ops::IdentityN(scope.WithOpName("id_b"), {s.output_false, s.output_true});
|
||||
|
||||
@ -663,22 +661,50 @@ TEST_F(DependencyOptimizerTest, IdentityN) {
|
||||
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||
TF_EXPECT_OK(status);
|
||||
|
||||
EXPECT_EQ(9, output.node_size());
|
||||
EXPECT_EQ("out1", output.node(5).name());
|
||||
EXPECT_EQ(1, output.node(5).input_size());
|
||||
EXPECT_EQ("s", output.node(5).input(0));
|
||||
EXPECT_EQ(8, output.node_size());
|
||||
|
||||
EXPECT_EQ("out2", output.node(6).name());
|
||||
EXPECT_EQ(1, output.node(6).input_size());
|
||||
EXPECT_EQ("s:1", output.node(6).input(0));
|
||||
auto out1_node = output.node(7);
|
||||
EXPECT_EQ("out1", out1_node.name());
|
||||
EXPECT_EQ(1, out1_node.input_size());
|
||||
EXPECT_EQ("s", out1_node.input(0));
|
||||
|
||||
EXPECT_EQ("out3", output.node(7).name());
|
||||
EXPECT_EQ(1, output.node(7).input_size());
|
||||
EXPECT_EQ("id_b", output.node(7).input(0));
|
||||
auto out2_node = output.node(4);
|
||||
EXPECT_EQ("out2", out2_node.name());
|
||||
EXPECT_EQ(1, out2_node.input_size());
|
||||
EXPECT_EQ("s:1", out2_node.input(0));
|
||||
|
||||
EXPECT_EQ("out4", output.node(8).name());
|
||||
EXPECT_EQ(1, output.node(8).input_size());
|
||||
EXPECT_EQ("id_b:1", output.node(8).input(0));
|
||||
auto out3_node = output.node(5);
|
||||
EXPECT_EQ("out3", out3_node.name());
|
||||
EXPECT_EQ(1, out3_node.input_size());
|
||||
EXPECT_EQ("s", out3_node.input(0));
|
||||
|
||||
auto out4_node = output.node(6);
|
||||
EXPECT_EQ("out4", out4_node.name());
|
||||
EXPECT_EQ(1, out4_node.input_size());
|
||||
EXPECT_EQ("s:1", out4_node.input(0));
|
||||
}
|
||||
|
||||
// Builds a two-input IdentityN ("id_n") whose outputs feed two Identity nodes
// and whose second output also gates a NoOp through a control dependency,
// then checks that the optimized graph still contains six nodes.
TEST_F(DependencyOptimizerTest, DoNotRemoveIdentityNWithControlDependency) {
  tensorflow::Scope root = tensorflow::Scope::NewRootScope();

  // Producers feeding the IdentityN.
  Output input1 = ops::Placeholder(root.WithOpName("input1"), DT_BOOL);
  Output input2 = ops::Const(root.WithOpName("input2"), {1, 2});

  // The IdentityN under test and its three consumers: two data consumers and
  // one control-dependency consumer.
  auto id_n = ops::IdentityN(root.WithOpName("id_n"), {input1, input2});
  Output out1 = ops::Identity(root.WithOpName("out1"), id_n[0]);
  Output out2 = ops::Identity(root.WithOpName("out2"), id_n[1]);
  auto out3 =
      ops::NoOp(root.WithOpName("out3").WithControlDependencies(id_n[1]));

  GrapplerItem grappler_item;
  TF_CHECK_OK(root.ToGraphDef(&grappler_item.graph));
  grappler_item.fetch = {"out1", "out2", "out3"};

  DependencyOptimizer optimizer;
  GraphDef result;
  TF_EXPECT_OK(optimizer.Optimize(nullptr, grappler_item, &result));

  // input1, input2, id_n, out1, out2 and out3 are all expected to remain.
  EXPECT_EQ(6, result.node_size());
}
|
||||
|
||||
TEST_F(DependencyOptimizerTest,
|
||||
|
Loading…
Reference in New Issue
Block a user