From 7141c4280873036fbdcccacf696c0fe12b5538c7 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Tue, 12 Feb 2019 10:34:06 -0800 Subject: [PATCH] [Grappler] Don't remove constant feed nodes in LoopOptimizer RemoveDeadBranches. PiperOrigin-RevId: 233633382 --- .../grappler/optimizers/loop_optimizer.cc | 36 +- .../core/grappler/optimizers/loop_optimizer.h | 4 +- .../optimizers/loop_optimizer_test.cc | 309 ++++++++++++++---- 3 files changed, 280 insertions(+), 69 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc index cf5e4db29f4..54776e7f80c 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -581,8 +581,19 @@ Status EvaluateBoolOpForConstantOperands(const NodeDef& op_node, return Status::OK(); } +// TODO(lyandy): Consolidate with ConstantFolding implementation. +bool IsReallyConstant(const NodeDef& node, + const absl::flat_hash_set<string>& feed_nodes) { + if (!IsConstant(node)) { + return false; + } + // If the node is fed it's not constant anymore. + return feed_nodes.find(node.name()) == feed_nodes.end(); +} + Status CheckForDeadFanout(const MutableGraphView& view, const NodeDef& switch_node, const NodeMap& node_map, + const absl::flat_hash_set<string>& feed_nodes, DeviceBase* cpu_device, ResourceMgr* resource_mgr, bool* has_dead_fanout, int* dead_fanout) { *has_dead_fanout = false; @@ -591,7 +602,7 @@ Status CheckForDeadFanout(const MutableGraphView& view, view.GetRegularFanin(switch_loopcond_port).node; // CASE 1: Control is a constant. 
- if (IsConstant(*switch_predicate)) { + if (IsReallyConstant(*switch_predicate, feed_nodes)) { Tensor selector; CHECK(selector.FromProto(switch_predicate->attr().at("value").tensor())); *has_dead_fanout = true; @@ -630,7 +641,7 @@ Status CheckForDeadFanout(const MutableGraphView& view, if (IsMerge(*node)) { merge_node = node; } - if (IsConstant(*node)) { + if (IsReallyConstant(*node, feed_nodes)) { constant_ctrl_input = node; constant_index = i; } @@ -646,7 +657,7 @@ Status CheckForDeadFanout(const MutableGraphView& view, if (IsEnter(*node)) { enter_node = node; } - if (IsConstant(*node)) { + if (IsReallyConstant(*node, feed_nodes)) { constant_init_node = node; } } @@ -654,7 +665,7 @@ Status CheckForDeadFanout(const MutableGraphView& view, if (constant_init_node != nullptr) return Status::OK(); for (const auto& input : enter_node->input()) { NodeDef* node = node_map.GetNode(input); - if (IsConstant(*node)) { + if (IsReallyConstant(*node, feed_nodes)) { constant_init_node = node; } } @@ -710,8 +721,12 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // TODO(srjoglekar): Figure out if we can optimize NodeMap creations across // optimizer passes. 
NodeMap node_map(optimized_graph); - TF_RETURN_IF_ERROR( - RemoveDeadBranches(item.NodesToPreserve(), node_map, optimized_graph)); + absl::flat_hash_set feed_nodes; + for (const auto& feed : item.feed) { + feed_nodes.insert(NodeName(feed.first)); + } + TF_RETURN_IF_ERROR(RemoveDeadBranches(item.NodesToPreserve(), node_map, + feed_nodes, optimized_graph)); } return Status::OK(); @@ -719,7 +734,8 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, Status LoopOptimizer::RemoveDeadBranches( const std::unordered_set& nodes_to_preserve, - const NodeMap& node_map, GraphDef* optimized_graph) { + const NodeMap& node_map, const absl::flat_hash_set& feed_nodes, + GraphDef* optimized_graph) { std::unordered_set dead_nodes; std::unordered_map> dead_merge_inputs; // TODO(bsteiner): also rewrite switches as identity. For now we just record @@ -737,9 +753,9 @@ Status LoopOptimizer::RemoveDeadBranches( int dead_fanout; bool has_dead_fanout; - TF_RETURN_IF_ERROR(CheckForDeadFanout(view, node, node_map, cpu_device_, - resource_mgr_.get(), &has_dead_fanout, - &dead_fanout)); + TF_RETURN_IF_ERROR(CheckForDeadFanout(view, node, node_map, feed_nodes, + cpu_device_, resource_mgr_.get(), + &has_dead_fanout, &dead_fanout)); if (!has_dead_fanout) { continue; } diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h index d467237a9a7..7fa1976f348 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer.h +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h @@ -60,7 +60,9 @@ class LoopOptimizer : public GraphOptimizer { }; Status RemoveDeadBranches(const std::unordered_set& nodes_to_preserve, - const NodeMap& node_map, GraphDef* optimized_graph); + const NodeMap& node_map, + const absl::flat_hash_set& feed_nodes, + GraphDef* optimized_graph); RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_; diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc 
b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc index 587767c23c3..9a6d6272dfd 100644 --- a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc @@ -504,11 +504,11 @@ void VerifyGraphsEqual(const GraphDef& original_graph, for (int i = 0; i < original_graph.node_size(); ++i) { const NodeDef& original = original_graph.node(i); const NodeDef& optimized = optimized_graph.node(i); - EXPECT_EQ(original.name(), optimized.name()) << func; - EXPECT_EQ(original.op(), optimized.op()) << func; - EXPECT_EQ(original.input_size(), optimized.input_size()) << func; + EXPECT_EQ(optimized.name(), original.name()) << func; + EXPECT_EQ(optimized.op(), original.op()) << func; + ASSERT_EQ(optimized.input_size(), original.input_size()) << func; for (int j = 0; j < original.input_size(); ++j) { - EXPECT_EQ(original.input(j), optimized.input(j)) << func; + EXPECT_EQ(optimized.input(j), original.input(j)) << func; } } } @@ -528,7 +528,7 @@ TEST_F(LoopOptimizerTest, NoOp) { VerifyGraphsEqual(item.graph, output, __FUNCTION__); } -TEST_F(LoopOptimizerTest, RemovePush_NoOp) { +TEST_F(LoopOptimizerTest, RemovePushNoOp) { GrapplerItem item; GraphDef& graph = item.graph; AddSimpleNode("c", "Const", {}, &graph); @@ -557,7 +557,7 @@ TEST_F(LoopOptimizerTest, RemovePush_NoOp) { VerifyGraphsEqual(item.graph, output, __FUNCTION__); } -TEST_F(LoopOptimizerTest, RemovePush_NoPopButStackLives) { +TEST_F(LoopOptimizerTest, RemovePushNoPopButStackLives) { GrapplerItem item; GraphDef& graph = item.graph; AddSimpleNode("c", "Const", {}, &graph); @@ -609,32 +609,32 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) { Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(13, output.node_size()); + EXPECT_EQ(output.node_size(), 13); for (int i = 0; i < output.node_size(); ++i) { const NodeDef& node = output.node(i); if (node.name() == "push1") { - EXPECT_EQ("Identity", node.op()); - 
EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("c", node.input(0)); - EXPECT_EQ("^stack1", node.input(1)); + EXPECT_EQ(node.op(), "Identity"); + ASSERT_EQ(node.input_size(), 2); + EXPECT_EQ(node.input(0), "c"); + EXPECT_EQ(node.input(1), "^stack1"); } else if (node.name() == "push2") { - EXPECT_EQ("Identity", node.op()); - EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("enter_c", node.input(0)); - EXPECT_EQ("^enter_stack2", node.input(1)); + EXPECT_EQ(node.op(), "Identity"); + ASSERT_EQ(node.input_size(), 2); + EXPECT_EQ(node.input(0), "enter_c"); + EXPECT_EQ(node.input(1), "^enter_stack2"); } else if (node.name() == "push3") { - EXPECT_EQ("Identity", node.op()); - EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("c", node.input(0)); - EXPECT_EQ("^stack3", node.input(1)); + EXPECT_EQ(node.op(), "Identity"); + ASSERT_EQ(node.input_size(), 2); + EXPECT_EQ(node.input(0), "c"); + EXPECT_EQ(node.input(1), "^stack3"); } else { const NodeDef& orig_node = item.graph.node(i); - EXPECT_EQ(orig_node.ShortDebugString(), node.ShortDebugString()); + EXPECT_EQ(node.ShortDebugString(), orig_node.ShortDebugString()); } } } -TEST_F(LoopOptimizerTest, RemoveDeadBranches_ConstantCondition) { +TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantCondition) { Scope scope = Scope::NewRootScope(); Output v_in = ops::Variable(scope.WithOpName("v_in"), {3}, DT_FLOAT); @@ -691,57 +691,57 @@ TEST_F(LoopOptimizerTest, RemoveDeadBranches_ConstantCondition) { for (const NodeDef& node : output.node()) { // These nodes should have been pruned - EXPECT_NE("Square1", node.name()); - EXPECT_NE("Sqrt2", node.name()); - EXPECT_NE("m5", node.name()); - EXPECT_NE("m7", node.name()); + EXPECT_NE(node.name(), "Square1"); + EXPECT_NE(node.name(), "Sqrt2"); + EXPECT_NE(node.name(), "m5"); + EXPECT_NE(node.name(), "m7"); if (node.name() == "m1") { // sqrt1 is dead - EXPECT_EQ("Identity", node.op()); - EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("square1", node.input(0)); + EXPECT_EQ(node.op(), "Identity"); + 
ASSERT_EQ(node.input_size(), 1); + EXPECT_EQ(node.input(0), "square1"); } else if (node.name() == "m2") { // both inputs are alive - EXPECT_EQ("Merge", node.op()); - EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("v_in", node.input(0)); - EXPECT_EQ("square1", node.input(1)); + EXPECT_EQ(node.op(), "Merge"); + ASSERT_EQ(node.input_size(), 2); + EXPECT_EQ(node.input(0), "v_in"); + EXPECT_EQ(node.input(1), "square1"); } else if (node.name() == "m3") { // sqrt1 is dead - EXPECT_EQ("Identity", node.op()); - EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("v_in", node.input(0)); + EXPECT_EQ(node.op(), "Identity"); + ASSERT_EQ(node.input_size(), 1); + EXPECT_EQ(node.input(0), "v_in"); } else if (node.name() == "m4") { // both inputs are alive - EXPECT_EQ("Merge", node.op()); - EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("square1", node.input(0)); - EXPECT_EQ("sqrt2", node.input(1)); + EXPECT_EQ(node.op(), "Merge"); + ASSERT_EQ(node.input_size(), 2); + EXPECT_EQ(node.input(0), "square1"); + EXPECT_EQ(node.input(1), "sqrt2"); } else if (node.name() == "m6") { // both inputs are alive and the control dependency can get triggered - EXPECT_EQ("Merge", node.op()); - EXPECT_EQ(3, node.input_size()); - EXPECT_EQ("v_in", node.input(0)); - EXPECT_EQ("square1", node.input(1)); - EXPECT_EQ("^sqrt2", node.input(2)); + EXPECT_EQ(node.op(), "Merge"); + ASSERT_EQ(node.input_size(), 3); + EXPECT_EQ(node.input(0), "v_in"); + EXPECT_EQ(node.input(1), "square1"); + EXPECT_EQ(node.input(2), "^sqrt2"); } else if (node.name() == "m8") { // The node is to be preserved because of a fetch - EXPECT_EQ("Merge", node.op()); - EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("id1", node.input(0)); - EXPECT_EQ("id2", node.input(1)); + EXPECT_EQ(node.op(), "Merge"); + ASSERT_EQ(node.input_size(), 2); + EXPECT_EQ(node.input(0), "id1"); + EXPECT_EQ(node.input(1), "id2"); } else if (node.name() == "m9") { // The node is to be preserved because of a fetch - EXPECT_EQ("Merge", node.op()); - EXPECT_EQ(2, 
node.input_size()); - EXPECT_EQ("id3", node.input(0)); - EXPECT_EQ("id4", node.input(1)); + EXPECT_EQ(node.op(), "Merge"); + ASSERT_EQ(2, node.input_size()); + EXPECT_EQ(node.input(0), "id3"); + EXPECT_EQ(node.input(1), "id4"); } } } -TEST_F(LoopOptimizerTest, RemoveDeadBranches_FullyRemoveDeadBranches) { +TEST_F(LoopOptimizerTest, RemoveDeadBranchesFullyRemoveDeadBranches) { const string gdef_ascii = R"EOF( node { name: "episodicreplaybuffer_add_readvariableop_resource" @@ -1153,7 +1153,7 @@ versions { << "Merge node was deleted, but it shouldn't have been."; } -TEST_F(LoopOptimizerTest, RemoveDeadBranches_ZeroIterWhile) { +TEST_F(LoopOptimizerTest, RemoveDeadBranchesZeroIterWhile) { const string gdef_ascii = R"EOF( node { name: "Const" @@ -1358,15 +1358,15 @@ versions { CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph)); item.fetch = {"while/Exit"}; auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - EXPECT_EQ(1, tensors_expected.size()); + ASSERT_EQ(tensors_expected.size(), 1); LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE, nullptr); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_CHECK_OK(status); auto tensors_got = EvaluateNodes(output, item.fetch); - EXPECT_EQ(1, tensors_got.size()); - test::ExpectTensorEqual(tensors_expected[0], tensors_got[0]); + ASSERT_EQ(tensors_got.size(), 1); + test::ExpectTensorEqual(tensors_got[0], tensors_expected[0]); int nodes_present = 0; for (const NodeDef& node : output.node()) { @@ -1382,7 +1382,200 @@ versions { } ++nodes_present; } - EXPECT_EQ(8, nodes_present); + EXPECT_EQ(nodes_present, 8); +} + +TEST_F(LoopOptimizerTest, RemoveDeadBranchesConstantFeed) { + const string gdef_ascii = R"EOF( +node { + name: "Const" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + 
string_val: "I\'m a value!" + } + } + } +} +node { + name: "cond/Switch_1" + op: "Switch" + input: "Const" + input: "Const_1" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_STRING + } + } + attr { + key: "_class" + value { + list { + s: "loc:@Const" + } + } + } +} +node { + name: "Const_1" + op: "Const" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_BOOL + tensor_shape { + } + bool_val: true + } + } + } +} +node { + name: "cond/Switch" + op: "Switch" + input: "Const_1" + input: "Const_1" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } +} +node { + name: "cond/switch_t" + op: "Identity" + input: "cond/Switch:1" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_BOOL + } + } +} +node { + name: "cond/Const" + op: "Const" + input: "^cond/switch_t" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "dtype" + value { + type: DT_STRING + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_STRING + tensor_shape { + dim { + size: 1 + } + } + string_val: "" + } + } + } +} +node { + name: "cond/Merge" + op: "Merge" + input: "cond/Switch_1" + input: "cond/Const" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_STRING + } + } +} +node { + name: "Identity" + op: "Identity" + input: "cond/Merge" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "T" + value { + type: DT_STRING + } + } +} +library { +} +versions { + producer: 27 +} + )EOF"; + + GrapplerItem item; + CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph)); + item.fetch = {"Identity"}; + Tensor feed_tensor(DT_BOOL, {}); + feed_tensor.flat()(1) = false; + 
item.feed.push_back({"Const_1", feed_tensor}); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + ASSERT_EQ(tensors_expected.size(), 1); + + LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE, nullptr); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_CHECK_OK(status); + auto tensors_got = EvaluateNodes(output, item.fetch); + ASSERT_EQ(tensors_got.size(), 1); + test::ExpectTensorEqual<string>(tensors_got[0], tensors_expected[0]); + + EXPECT_EQ(output.node_size(), 8); + + // No rewrite because branch has a constant feed node. + bool found = false; + for (const NodeDef& node : output.node()) { + if (node.name() == "cond/Merge") { + EXPECT_EQ(node.op(), "Merge"); + ASSERT_EQ(node.input_size(), 2); + EXPECT_EQ(node.input(0), "cond/Switch_1"); + EXPECT_EQ(node.input(1), "cond/Const"); + found = true; + break; + } + } + EXPECT_TRUE(found); } } // namespace grappler