Extend Identity optimizations to IdentityN.
PiperOrigin-RevId: 219327001
This commit is contained in:
parent
55484438d7
commit
507c566376
@ -57,7 +57,7 @@ bool RemoveInput(NodeDef* node, const string& input, NodeMap* node_map) {
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
|
bool DependencyOptimizer::SafeToRemoveIdentity(const NodeDef& node) const {
|
||||||
if (!IsIdentity(node) && !IsIdentityNSingleInput(node)) {
|
if (!IsIdentity(node) && !IsIdentityN(node)) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -133,15 +133,56 @@ bool DependencyOptimizer::SafeToConvertToNoOp(const NodeDef& node) const {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int DependencyOptimizer::NumEdgesIfBypassed(
|
||||||
|
const NodeDef& node, const std::vector<NodeDef*>& output_nodes) const {
|
||||||
|
const bool is_multi_input_identity_n =
|
||||||
|
IsIdentityN(node) && !IsIdentityNSingleInput(node);
|
||||||
|
const int num_outputs = output_nodes.size();
|
||||||
|
const int num_inputs = node.input_size();
|
||||||
|
|
||||||
|
if (is_multi_input_identity_n) {
|
||||||
|
// multi-input identity_n with input/output control dependencies will likely
|
||||||
|
// increase number of edges after optimization.
|
||||||
|
int num_edges_if_bypassed(0);
|
||||||
|
for (string input_node_name : node.input()) {
|
||||||
|
if (IsControlInput(input_node_name)) {
|
||||||
|
num_edges_if_bypassed += num_outputs;
|
||||||
|
} else {
|
||||||
|
++num_edges_if_bypassed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto consumer : output_nodes) {
|
||||||
|
for (int j = 0; j < consumer->input_size(); ++j) {
|
||||||
|
const string& consumer_input = consumer->input(j);
|
||||||
|
int consumer_input_pos;
|
||||||
|
StringPiece consumer_input_node_name =
|
||||||
|
ParseNodeNameAsStringPiece(consumer_input, &consumer_input_pos);
|
||||||
|
if (consumer_input_node_name == node.name()) {
|
||||||
|
if (IsControlInput(consumer_input)) {
|
||||||
|
num_edges_if_bypassed += num_inputs;
|
||||||
|
} else {
|
||||||
|
++num_edges_if_bypassed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return num_edges_if_bypassed;
|
||||||
|
} else {
|
||||||
|
return num_inputs * num_outputs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
bool DependencyOptimizer::BypassingNodeIsBeneficial(
|
bool DependencyOptimizer::BypassingNodeIsBeneficial(
|
||||||
const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
|
const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
|
||||||
const std::vector<NodeDef*>& output_nodes) const {
|
const std::vector<NodeDef*>& output_nodes) const {
|
||||||
const bool is_identity = IsIdentity(node) || IsIdentityNSingleInput(node);
|
const bool is_identity = IsIdentity(node) || IsIdentityNSingleInput(node);
|
||||||
|
const bool is_multi_input_identity_n =
|
||||||
|
IsIdentityN(node) && !IsIdentityNSingleInput(node);
|
||||||
const int num_outputs = output_nodes.size();
|
const int num_outputs = output_nodes.size();
|
||||||
const int num_inputs = node.input_size();
|
const int num_inputs = node.input_size();
|
||||||
|
|
||||||
// Don't increase the number of edges in the graph.
|
if (NumEdgesIfBypassed(node, output_nodes) > num_inputs + num_outputs) {
|
||||||
if (num_inputs * num_outputs > num_inputs + num_outputs) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -166,7 +207,9 @@ bool DependencyOptimizer::BypassingNodeIsBeneficial(
|
|||||||
for (NodeDef* output_node : output_nodes) {
|
for (NodeDef* output_node : output_nodes) {
|
||||||
num_cross_out += static_cast<int>(output_node->device() != node_dev);
|
num_cross_out += static_cast<int>(output_node->device() != node_dev);
|
||||||
}
|
}
|
||||||
if (is_identity && num_cross_in > 0 && num_cross_out > 0) {
|
|
||||||
|
if ((is_identity || is_multi_input_identity_n) && num_cross_in > 0 &&
|
||||||
|
num_cross_out > 0) {
|
||||||
// This identity node follows a device crossing, so it might be
|
// This identity node follows a device crossing, so it might be
|
||||||
// following a _Recv node after partioning. Do not remove such nodes,
|
// following a _Recv node after partioning. Do not remove such nodes,
|
||||||
// unless they only have consumers on the same device as themselves.
|
// unless they only have consumers on the same device as themselves.
|
||||||
@ -194,6 +237,8 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
|
|||||||
NodeDef* node = optimized_graph_->mutable_node(node_idx);
|
NodeDef* node = optimized_graph_->mutable_node(node_idx);
|
||||||
const bool is_noop = IsNoOp(*node);
|
const bool is_noop = IsNoOp(*node);
|
||||||
const bool is_identity = IsIdentity(*node) || IsIdentityNSingleInput(*node);
|
const bool is_identity = IsIdentity(*node) || IsIdentityNSingleInput(*node);
|
||||||
|
const bool is_multi_input_identity =
|
||||||
|
IsIdentityN(*node) && !IsIdentityNSingleInput(*node);
|
||||||
const string node_name = node->name();
|
const string node_name = node->name();
|
||||||
// Constant nodes with no input control dependency are always executed early,
|
// Constant nodes with no input control dependency are always executed early,
|
||||||
// so we can prune all their output control dependencies.
|
// so we can prune all their output control dependencies.
|
||||||
@ -315,7 +360,8 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
|
|||||||
// y --^> | | --^> b /\ +---+
|
// y --^> | | --^> b /\ +---+
|
||||||
// +----------+ y --^> b
|
// +----------+ y --^> b
|
||||||
|
|
||||||
if (is_noop || (is_identity && SafeToRemoveIdentity(*node))) {
|
if (is_noop || ((is_identity || is_multi_input_identity) &&
|
||||||
|
SafeToRemoveIdentity(*node))) {
|
||||||
const auto& output_node_set = node_map_->GetOutputs(node_name);
|
const auto& output_node_set = node_map_->GetOutputs(node_name);
|
||||||
const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
|
const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
|
||||||
output_node_set.end());
|
output_node_set.end());
|
||||||
@ -343,11 +389,11 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
|
|||||||
const NodeDef* input = input_nodes[i];
|
const NodeDef* input = input_nodes[i];
|
||||||
// Forward dependency from input to consumer if it doesn't already
|
// Forward dependency from input to consumer if it doesn't already
|
||||||
// depend on it.
|
// depend on it.
|
||||||
if (is_identity && i == 0) {
|
if ((is_identity && i == 0) ||
|
||||||
|
(is_multi_input_identity && !IsControlInput(node->input(i)))) {
|
||||||
// Replace regular input from Identity node.
|
// Replace regular input from Identity node.
|
||||||
bool found_input = false;
|
|
||||||
string new_input;
|
string new_input;
|
||||||
const string& input_to_forward = node->input(0);
|
const string& input_to_forward = node->input(i);
|
||||||
CHECK(!IsControlInput(input_to_forward));
|
CHECK(!IsControlInput(input_to_forward));
|
||||||
for (int j = 0; j < consumer->input_size(); ++j) {
|
for (int j = 0; j < consumer->input_size(); ++j) {
|
||||||
const string& old_input = consumer->input(j);
|
const string& old_input = consumer->input(j);
|
||||||
@ -355,22 +401,19 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
|
|||||||
StringPiece old_input_node_name =
|
StringPiece old_input_node_name =
|
||||||
ParseNodeNameAsStringPiece(old_input, &old_input_pos);
|
ParseNodeNameAsStringPiece(old_input, &old_input_pos);
|
||||||
if (old_input_node_name == node_name) {
|
if (old_input_node_name == node_name) {
|
||||||
if (old_input_pos >= 0) {
|
if (old_input_pos == i) {
|
||||||
// Regular input
|
// Regular input
|
||||||
new_input = input_to_forward;
|
new_input = input_to_forward;
|
||||||
node_map_->UpdateInput(consumer->name(), old_input, new_input);
|
node_map_->UpdateInput(consumer->name(), old_input, new_input);
|
||||||
consumer->set_input(j, new_input);
|
consumer->set_input(j, new_input);
|
||||||
found_input = true;
|
} else if (old_input_pos == -1) {
|
||||||
} else {
|
|
||||||
// Control dependency
|
// Control dependency
|
||||||
new_input = AsControlDependency(NodeName(input_to_forward));
|
new_input = AsControlDependency(NodeName(input_to_forward));
|
||||||
node_map_->UpdateInput(consumer->name(), old_input, new_input);
|
node_map_->UpdateInput(consumer->name(), old_input, new_input);
|
||||||
consumer->set_input(j, new_input);
|
consumer->set_input(j, new_input);
|
||||||
found_input = true;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
CHECK(found_input);
|
|
||||||
updated_consumer = true;
|
updated_consumer = true;
|
||||||
} else {
|
} else {
|
||||||
// Forward dependency from input to consumer if it doesn't already
|
// Forward dependency from input to consumer if it doesn't already
|
||||||
@ -415,7 +458,7 @@ Status DependencyOptimizer::OptimizeDependencies() {
|
|||||||
std::set<int> nodes_to_delete;
|
std::set<int> nodes_to_delete;
|
||||||
for (int i = 0; i < optimized_graph_->node_size(); ++i) {
|
for (int i = 0; i < optimized_graph_->node_size(); ++i) {
|
||||||
const NodeDef& node = optimized_graph_->node(i);
|
const NodeDef& node = optimized_graph_->node(i);
|
||||||
if (IsNoOp(node) || IsIdentity(node) || IsIdentityNSingleInput(node) ||
|
if (IsNoOp(node) || IsIdentity(node) || IsIdentityN(node) ||
|
||||||
IsConstant(node) || SafeToConvertToNoOp(node)) {
|
IsConstant(node) || SafeToConvertToNoOp(node)) {
|
||||||
nodes_to_simplify.PushBack(i);
|
nodes_to_simplify.PushBack(i);
|
||||||
}
|
}
|
||||||
|
@ -48,7 +48,8 @@ class DependencyOptimizer : public GraphOptimizer {
|
|||||||
bool BypassingNodeIsBeneficial(
|
bool BypassingNodeIsBeneficial(
|
||||||
const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
|
const NodeDef& node, const std::vector<NodeDef*>& input_nodes,
|
||||||
const std::vector<NodeDef*>& output_nodes) const;
|
const std::vector<NodeDef*>& output_nodes) const;
|
||||||
|
int NumEdgesIfBypassed(const NodeDef& node,
|
||||||
|
const std::vector<NodeDef*>& output_nodes) const;
|
||||||
// Returns true if node is not an Identity node or if it is an Identity
|
// Returns true if node is not an Identity node or if it is an Identity
|
||||||
// that is safe to remove.
|
// that is safe to remove.
|
||||||
bool SafeToRemoveIdentity(const NodeDef& node) const;
|
bool SafeToRemoveIdentity(const NodeDef& node) const;
|
||||||
|
@ -634,7 +634,7 @@ TEST_F(DependencyOptimizerTest, IdentityInputs) {
|
|||||||
EXPECT_EQ("s:1", output.node(5).input(0));
|
EXPECT_EQ("s:1", output.node(5).input(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DependencyOptimizerTest, IdentityN) {
|
TEST_F(DependencyOptimizerTest, RemoveIdentityN_SwitchInput) {
|
||||||
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
|
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
|
||||||
Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL);
|
Output b = ops::Placeholder(scope.WithOpName("b"), DT_BOOL);
|
||||||
Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT);
|
Output x = ops::RandomUniform(scope.WithOpName("x"), {1, 2}, DT_FLOAT);
|
||||||
@ -643,8 +643,6 @@ TEST_F(DependencyOptimizerTest, IdentityN) {
|
|||||||
// IdentityN nodes to be removed.
|
// IdentityN nodes to be removed.
|
||||||
auto id_f = ops::IdentityN(scope.WithOpName("id_f"), {s.output_false});
|
auto id_f = ops::IdentityN(scope.WithOpName("id_f"), {s.output_false});
|
||||||
auto id_t = ops::IdentityN(scope.WithOpName("id_t"), {s.output_true});
|
auto id_t = ops::IdentityN(scope.WithOpName("id_t"), {s.output_true});
|
||||||
|
|
||||||
// IdentityN node that can't be removed.
|
|
||||||
auto id_b =
|
auto id_b =
|
||||||
ops::IdentityN(scope.WithOpName("id_b"), {s.output_false, s.output_true});
|
ops::IdentityN(scope.WithOpName("id_b"), {s.output_false, s.output_true});
|
||||||
|
|
||||||
@ -663,22 +661,50 @@ TEST_F(DependencyOptimizerTest, IdentityN) {
|
|||||||
Status status = optimizer.Optimize(nullptr, item, &output);
|
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||||
TF_EXPECT_OK(status);
|
TF_EXPECT_OK(status);
|
||||||
|
|
||||||
EXPECT_EQ(9, output.node_size());
|
EXPECT_EQ(8, output.node_size());
|
||||||
EXPECT_EQ("out1", output.node(5).name());
|
|
||||||
EXPECT_EQ(1, output.node(5).input_size());
|
|
||||||
EXPECT_EQ("s", output.node(5).input(0));
|
|
||||||
|
|
||||||
EXPECT_EQ("out2", output.node(6).name());
|
auto out1_node = output.node(7);
|
||||||
EXPECT_EQ(1, output.node(6).input_size());
|
EXPECT_EQ("out1", out1_node.name());
|
||||||
EXPECT_EQ("s:1", output.node(6).input(0));
|
EXPECT_EQ(1, out1_node.input_size());
|
||||||
|
EXPECT_EQ("s", out1_node.input(0));
|
||||||
|
|
||||||
EXPECT_EQ("out3", output.node(7).name());
|
auto out2_node = output.node(4);
|
||||||
EXPECT_EQ(1, output.node(7).input_size());
|
EXPECT_EQ("out2", out2_node.name());
|
||||||
EXPECT_EQ("id_b", output.node(7).input(0));
|
EXPECT_EQ(1, out2_node.input_size());
|
||||||
|
EXPECT_EQ("s:1", out2_node.input(0));
|
||||||
|
|
||||||
EXPECT_EQ("out4", output.node(8).name());
|
auto out3_node = output.node(5);
|
||||||
EXPECT_EQ(1, output.node(8).input_size());
|
EXPECT_EQ("out3", out3_node.name());
|
||||||
EXPECT_EQ("id_b:1", output.node(8).input(0));
|
EXPECT_EQ(1, out3_node.input_size());
|
||||||
|
EXPECT_EQ("s", out3_node.input(0));
|
||||||
|
|
||||||
|
auto out4_node = output.node(6);
|
||||||
|
EXPECT_EQ("out4", out4_node.name());
|
||||||
|
EXPECT_EQ(1, out4_node.input_size());
|
||||||
|
EXPECT_EQ("s:1", out4_node.input(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DependencyOptimizerTest, DoNotRemoveIdentityNWithControlDependency) {
|
||||||
|
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
|
||||||
|
Output input1 = ops::Placeholder(scope.WithOpName("input1"), DT_BOOL);
|
||||||
|
Output input2 = ops::Const(scope.WithOpName("input2"), {1, 2});
|
||||||
|
|
||||||
|
auto id_n = ops::IdentityN(scope.WithOpName("id_n"), {input1, input2});
|
||||||
|
Output out1 = ops::Identity(scope.WithOpName("out1"), id_n[0]);
|
||||||
|
Output out2 = ops::Identity(scope.WithOpName("out2"), id_n[1]);
|
||||||
|
auto out3 =
|
||||||
|
ops::NoOp(scope.WithOpName("out3").WithControlDependencies(id_n[1]));
|
||||||
|
|
||||||
|
GrapplerItem item;
|
||||||
|
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
|
||||||
|
item.fetch = {"out1", "out2", "out3"};
|
||||||
|
|
||||||
|
DependencyOptimizer optimizer;
|
||||||
|
GraphDef optimized_graph_def;
|
||||||
|
Status status = optimizer.Optimize(nullptr, item, &optimized_graph_def);
|
||||||
|
TF_EXPECT_OK(status);
|
||||||
|
|
||||||
|
EXPECT_EQ(6, optimized_graph_def.node_size());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DependencyOptimizerTest,
|
TEST_F(DependencyOptimizerTest,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user