From 6744f3c0fe5435df84e69d7165731d24dbc93e7b Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Wed, 19 Dec 2018 10:26:41 -0800 Subject: [PATCH] [Grappler] Add helper methods for controlling fanin deduping and adding of controlling fanins for Switch ops in MutableGraphView. PiperOrigin-RevId: 226194053 --- tensorflow/core/grappler/BUILD | 1 + .../core/grappler/mutable_graph_view.cc | 134 +++++++++-- tensorflow/core/grappler/mutable_graph_view.h | 53 ++++- .../core/grappler/mutable_graph_view_test.cc | 214 +++++++++++++++++- 4 files changed, 381 insertions(+), 21 deletions(-) diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index 6e3012000fc..6de12192ba8 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -196,6 +196,7 @@ tf_cc_test( ":utils", "//tensorflow/cc:cc_ops", "//tensorflow/core:graph", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", diff --git a/tensorflow/core/grappler/mutable_graph_view.cc b/tensorflow/core/grappler/mutable_graph_view.cc index 224b720328a..ca4d5255c0f 100644 --- a/tensorflow/core/grappler/mutable_graph_view.cc +++ b/tensorflow/core/grappler/mutable_graph_view.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" #include "absl/strings/substitute.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -178,27 +179,21 @@ void MutableGraphView::UpdateFanouts(NodeDef* from_node, NodeDef* to_node) { } } -bool MutableGraphView::AddFanin(NodeDef* node, const TensorId& fanin) { - NodeDef* fanin_node = GetNode(fanin.node()); - if (fanin_node == nullptr) { - return false; - } - +bool MutableGraphView::AddFaninInternal(NodeDef* node, + const OutputPort& fanin) { int num_non_controlling_fanins = NumFanins(*node, /*include_controlling_nodes=*/false); InputPort input; input.node = node; - input.port_id = fanin.index() == Graph::kControlSlot + input.port_id = fanin.port_id == Graph::kControlSlot ? Graph::kControlSlot : num_non_controlling_fanins; - OutputPort fanin_port(fanin_node, fanin.index()); - - if (!gtl::InsertIfNotPresent(&fanouts()[fanin_port], input)) { + if (!gtl::InsertIfNotPresent(&fanouts()[fanin], input)) { return false; } - node->add_input(TensorIdToString(fanin)); - if (fanin.index() > Graph::kControlSlot) { + node->add_input(TensorIdToString({fanin.node->name(), fanin.port_id})); + if (fanin.port_id > Graph::kControlSlot) { int node_input_size = node->input_size() - 1; // If there are control dependencies in node, move newly inserted fanin to // be before such control dependencies. @@ -210,6 +205,14 @@ bool MutableGraphView::AddFanin(NodeDef* node, const TensorId& fanin) { return true; } +bool MutableGraphView::AddFaninInternal(NodeDef* node, const TensorId& fanin) { + NodeDef* fanin_node = GetNode(fanin.node()); + if (fanin_node == nullptr) { + return false; + } + return AddFaninInternal(node, {fanin_node, fanin.index()}); +} + bool MutableGraphView::AddFanin(absl::string_view node_name, const TensorId& fanin) { if (!IsTensorIdPortValid(fanin)) { @@ -219,7 +222,7 @@ bool MutableGraphView::AddFanin(absl::string_view node_name, if (node == nullptr) { return false; } - return AddFanin(node, fanin); + return AddFaninInternal(node, fanin); } bool MutableGraphView::RemoveFanins(NodeDef* node, @@ -318,7 +321,7 @@ bool MutableGraphView::UpdateFanin(absl::string_view node_name, if (is_from_fanin_control || is_to_fanin_control) { bool modified = RemoveFanins(node, {from_fanin}); if (!HasFanin(*node, to_fanin)) { - modified |= AddFanin(node, to_fanin); + modified |= AddFaninInternal(node, to_fanin); } return modified; } @@ -357,6 +360,109 @@ bool MutableGraphView::UpdateFanin(absl::string_view node_name, return modified; } +bool MutableGraphView::DedupControllingFanins(NodeDef* node) { + absl::flat_hash_set fanins; + absl::flat_hash_set removed_fanins; + int pos = 0; + const int last_idx = node->input_size() - 1; + int last_pos = last_idx; + while (pos <= last_pos) { + const string& input = node->input(pos); + TensorId tensor_id = ParseTensorName(input); + if (!gtl::InsertIfNotPresent(&fanins, tensor_id.node()) && + IsControlInput(tensor_id)) { + node->mutable_input()->SwapElements(pos, last_pos--); + removed_fanins.insert(input); + } else { + ++pos; + } + } + + if (last_pos < last_idx) { + absl::flat_hash_set retained_fanins( + node->input().begin(), node->input().begin() + last_pos + 1); + for (const auto& removed : removed_fanins) { + if (!retained_fanins.contains(removed)) { + OutputPort fanin(nodes()[ParseTensorName(removed).node()], + Graph::kControlSlot); + fanouts()[fanin].erase({node, Graph::kControlSlot}); + } + } + node->mutable_input()->DeleteSubrange(last_pos + 1, last_idx - last_pos); + return true; + } + + return false; +} + +bool MutableGraphView::DedupControllingFanins(absl::string_view node_name) { + NodeDef* node = GetNode(node_name); + if (node == nullptr) { + return false; + } + return DedupControllingFanins(node); +} + +bool MutableGraphView::DedupControllingFanins() { + const int num_nodes = graph()->node_size(); + bool modified = false; + for (int i = 0; i < num_nodes; ++i) { + modified |= DedupControllingFanins(graph()->mutable_node(i)); + } + return modified; +} + +bool MutableGraphView::AddControllingFanin(absl::string_view node_name, + const TensorId& fanin) { + NodeDef* node = GetNode(node_name); + if (node == nullptr) { + return false; + } + NodeDef* fanin_node = GetNode(fanin.node()); + if (fanin_node == nullptr) { + return false; + } + if (fanin.index() == Graph::kControlSlot) { + return AddFaninInternal(node, {fanin_node, Graph::kControlSlot}); + } + + if (!IsSwitch(*fanin_node)) { + return AddFaninInternal(node, {fanin_node, Graph::kControlSlot}); + } else { + // We can't anchor control dependencies directly on the switch node: unlike + // other nodes only one of the outputs of the switch node will be generated + // when the switch node is executed, and we need to make sure the control + // dependency is only triggered when the corresponding output is triggered. + // We start by looking for an identity node connected to the output of the + // switch node, and use it to anchor the control dependency. + auto fanouts = GetFanouts(*fanin_node, /*include_controlled_nodes=*/false); + for (auto fanout : fanouts) { + if (IsIdentity(*fanout.node) || IsIdentityNSingleInput(*fanout.node)) { + if (ParseTensorName(fanout.node->input(0)) == fanin) { + return AddFaninInternal(node, {fanout.node, Graph::kControlSlot}); + } + } + } + // We haven't found an existing node where we can anchor the control + // dependency: add a new identity node. + string ctrl_dep_name = AddPrefixToNodeName( + absl::StrCat(fanin.node(), "_", fanin.index()), kMutableGraphViewCtrl); + + NodeDef* ctrl_dep_node = GetNode(ctrl_dep_name); + if (ctrl_dep_node == nullptr) { + NodeDef new_node; + new_node.set_name(ctrl_dep_name); + new_node.set_op("Identity"); + new_node.set_device(fanin_node->device()); + (*new_node.mutable_attr())["T"].set_type( + fanin_node->attr().at("T").type()); + new_node.add_input(TensorIdToString(fanin)); + ctrl_dep_node = AddNode(std::move(new_node)); + } + return AddFaninInternal(node, {ctrl_dep_node, Graph::kControlSlot}); + } +} + void MutableGraphView::DeleteNodes(const std::set& nodes_to_delete) { for (const string& node_name_to_delete : nodes_to_delete) RemoveFaninsInternal(nodes().at(node_name_to_delete), diff --git a/tensorflow/core/grappler/mutable_graph_view.h b/tensorflow/core/grappler/mutable_graph_view.h index 8025b8ca778..f7c2a1118e5 100644 --- a/tensorflow/core/grappler/mutable_graph_view.h +++ b/tensorflow/core/grappler/mutable_graph_view.h @@ -31,6 +31,8 @@ limitations under the License. namespace tensorflow { namespace grappler { +const char kMutableGraphViewCtrl[] = "ConstantFoldingCtrl"; + // A utility class to simplify the traversal of a GraphDef that, unlike // GraphView, supports updating the graph. Note that you should not modify the // graph separately, because the view will get out of sync. @@ -102,6 +104,38 @@ class MutableGraphView : public internal::GraphViewInternal { bool UpdateFanin(absl::string_view node_name, const TensorId& from_fanin, const TensorId& to_fanin); + // Removes redundant control fanins from node `node_name`. + // + // This will return true iff the node is modified. + // TODO(lyandy): Measure performance of deduping on every AddFanin compared to + // deduping once at the end. + bool DedupControllingFanins(absl::string_view node_name); + + // Removes redundant control fanins from all nodes in the graph. + // + // This will return true iff the node is modified. + bool DedupControllingFanins(); + + // Adds a control dependency to the target node named `node_name`. + // + // Case 1: If the fanin is not a Switch node, the control dependency is simply + // added to the target node: + // + // fanin -^> target node. + // + // Case 2: If the fanin is a Switch node, we cannot anchor a control + // dependency on it, because unlike other nodes, only one of its outputs will + // be generated when the node is activated. In this case, we try to find an + // Identity/IdentityN node in the fanout of the relevant port of the Switch + // and add it as a fanin to the target node. If no such Identity/IdentityN + // node can be found, a new Identity node will be created. In both cases, we + // end up with: + // + // fanin -> Identity{N} -^> target node. + // + // This will return true iff the node is modified. + bool AddControllingFanin(absl::string_view node_name, const TensorId& fanin); + // Deletes nodes from the graph. void DeleteNodes(const std::set& nodes_to_delete); @@ -121,11 +155,19 @@ class MutableGraphView : public internal::GraphViewInternal { // behavior is undefined. void UpdateFanouts(NodeDef* from_node, NodeDef* to_node); - // Remove fanins of the deleted node from internal state. Control dependencies - // are retained iff keep_controlling_fanins is true. + // Removes fanins of the deleted node from internal state. Control + // dependencies are retained iff keep_controlling_fanins is true. void RemoveFaninsInternal(NodeDef* deleted_node, bool keep_controlling_fanins); + // Add fanin to node. If fanin is a control dependency, existing control + // dependencies will be checked first before adding. Otherwise fanin will be + // added after existing non control dependency inputs. + // + // This will return true iff the node is modified. If a control dependency + // already exists, the node will not be modified. + bool AddFaninInternal(NodeDef* node, const OutputPort& fanin); + // Add fanin to node. If the node or fanin do not exist in the graph, nothing // will be modified in the graph. If fanin is a control dependency, existing // control dependencies will be checked first before adding. Otherwise fanin @@ -133,10 +175,13 @@ class MutableGraphView : public internal::GraphViewInternal { // // This will return true iff the node is modified. If a control dependency // already exists, the node will not be modified. - bool AddFanin(NodeDef* node, const TensorId& fanin); + bool AddFaninInternal(NodeDef* node, const TensorId& fanin); - // Remove any fanin in node that matches to a fanin in fanins. + // Removes any fanin in node that matches to a fanin in fanins. bool RemoveFanins(NodeDef* node, absl::Span fanins); + + // Removes redundant control fanins from node. + bool DedupControllingFanins(NodeDef* node); }; } // end namespace grappler diff --git a/tensorflow/core/grappler/mutable_graph_view_test.cc b/tensorflow/core/grappler/mutable_graph_view_test.cc index cd7e638595e..cdc212f6f9e 100644 --- a/tensorflow/core/grappler/mutable_graph_view_test.cc +++ b/tensorflow/core/grappler/mutable_graph_view_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" @@ -35,7 +36,7 @@ TEST(MutableGraphViewTest, AddAndUpdateFanouts) { NDef("other", "NotImportant", {}, {}), NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}), NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"})}, - /* empty function library */ {}); + /*funcs=*/{}); MutableGraphView graph(&graph_def); @@ -78,7 +79,7 @@ TEST(MutableGraphViewTest, AddAndUpdateFanoutsWithoutSelfLoops) { GraphDef graph_def = test::function::GDef({NDef("bar", "NotImportant", {}, {}), NDef("foo", "NotImportant", {"bar", "^bar"})}, - /* empty function library */ {}); + /*funcs=*/{}); MutableGraphView graph(&graph_def); @@ -462,6 +463,213 @@ TEST(MutableGraphViewTest, UpdateFanin) { /*modified=*/false, /*expected_node=*/nullptr); } +GraphDef SimpleDuplicateControllingFaninsGraph() { + // Actual node.op() is not important in this test. + GraphDef graph_def = test::function::GDef( + {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}), + NDef("foo_1", "NotImportant", {"a", "b:1", "^b"}), + NDef("foo_2", "NotImportant", {"a", "^b", "^b"}), + NDef("foo_3", "NotImportant", {"a", "b:1", "^b", "^b"}), + NDef("foo_4", "NotImportant", {"a:2", "b:1", "^b", "^b", "^a", "^a"})}, + /*funcs=*/{}); + return graph_def; +} + +void CheckDedupControllingFaninsForNode(MutableGraphView* graph, + absl::string_view node_name, + const NodeDef* expected_node) { + // Deduping again should result in no change. + EXPECT_FALSE(graph->DedupControllingFanins(node_name)); + NodeDef* node = graph->GetNode(node_name); + ASSERT_NE(node, nullptr); + ASSERT_EQ(node->input_size(), expected_node->input_size()); + CompareNodeInputs(*graph, expected_node, node); + for (int i = 0; i < node->input_size(); ++i) { + TensorId tensor_id = ParseTensorName(node->input(i)); + if (tensor_id.index() > Graph::kControlSlot) { + CheckFanout(*graph, {tensor_id.node(), Graph::kControlSlot}, node_name); + } + } +} + +void TestDedupControllingFaninsForNode(MutableGraphView* graph, + absl::string_view node_name, + const NodeDef* expected_node) { + EXPECT_TRUE(graph->DedupControllingFanins(node_name)); + CheckDedupControllingFaninsForNode(graph, node_name, expected_node); +} + +TEST(MutableGraphViewTest, DedupControllingFaninsForNode) { + GraphDef graph_def = SimpleDuplicateControllingFaninsGraph(); + + MutableGraphView graph(&graph_def); + + NodeDef expected_node; + // Remove redundant control dependency '^b'. + expected_node = NDef("", "", {"a", "b:1"}); + TestDedupControllingFaninsForNode(&graph, "foo_1", &expected_node); + // Remove extra control dependency '^b'. + expected_node = NDef("", "", {"a", "^b"}); + TestDedupControllingFaninsForNode(&graph, "foo_2", &expected_node); + // Remove redundant and extra control dependencies '^b'. + expected_node = NDef("", "", {"a", "b:1"}); + TestDedupControllingFaninsForNode(&graph, "foo_3", &expected_node); + // Remove multiple redundant control dependencies. + expected_node = NDef("", "", {"a:2", "b:1"}); + TestDedupControllingFaninsForNode(&graph, "foo_4", &expected_node); + // Missing node. + EXPECT_FALSE(graph.DedupControllingFanins("missing")); +} + +TEST(MutableGraphViewTest, DedupControllingFaninsForGraph) { + GraphDef graph_def = SimpleDuplicateControllingFaninsGraph(); + + MutableGraphView graph(&graph_def); + EXPECT_TRUE(graph.DedupControllingFanins()); + // Deduping again should result in no change. + EXPECT_FALSE(graph.DedupControllingFanins()); + + NodeDef expected_node; + // Remove redundant control dependency '^b'. + expected_node = NDef("", "", {"a", "b:1"}); + CheckDedupControllingFaninsForNode(&graph, "foo_1", &expected_node); + // Remove extra control dependency '^b'. + expected_node = NDef("", "", {"a", "^b"}); + CheckDedupControllingFaninsForNode(&graph, "foo_2", &expected_node); + // Remove redundant and extra control dependencies '^b'. + expected_node = NDef("", "", {"a", "b:1"}); + CheckDedupControllingFaninsForNode(&graph, "foo_3", &expected_node); + // Remove multiple redundant control dependencies. + expected_node = NDef("", "", {"a:2", "b:1"}); + CheckDedupControllingFaninsForNode(&graph, "foo_4", &expected_node); +} + +TEST(MutableGraphViewTest, AddControllingFaninMissing) { + // Actual node.op() is not important in this test. + GraphDef graph_def = test::function::GDef( + {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})}, + /*funcs=*/{}); + + MutableGraphView graph(&graph_def); + // Missing fanin. + EXPECT_FALSE(graph.AddControllingFanin("a", {"c", Graph::kControlSlot})); + // Missing node. + EXPECT_FALSE(graph.AddControllingFanin("d", {"a", Graph::kControlSlot})); + // Missing node and fanin. + EXPECT_FALSE(graph.AddControllingFanin("c", {"d", Graph::kControlSlot})); + + ASSERT_EQ(graph.graph()->node_size(), 2); + NodeDef* a = graph.GetNode("a"); + ASSERT_NE(a, nullptr); + ASSERT_EQ(a->input_size(), 0); + NodeDef* b = graph.GetNode("b"); + ASSERT_NE(b, nullptr); + ASSERT_EQ(b->input_size(), 0); +} + +TEST(MutableGraphViewTest, AddControllingFaninExistingControl) { + // Actual node.op() is not important in this test. + GraphDef graph_def = test::function::GDef( + {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})}, + /*funcs=*/{}); + + MutableGraphView graph(&graph_def); + EXPECT_TRUE(graph.AddControllingFanin("a", {"b", Graph::kControlSlot})); + EXPECT_FALSE(graph.AddControllingFanin("a", {"b", Graph::kControlSlot})); + + ASSERT_EQ(graph.graph()->node_size(), 2); + NodeDef* a = graph.GetNode("a"); + ASSERT_NE(a, nullptr); + ASSERT_EQ(a->input_size(), 1); + EXPECT_EQ(a->input(0), "^b"); + NodeDef* b = graph.GetNode("b"); + ASSERT_NE(b, nullptr); + ASSERT_EQ(b->input_size(), 0); +} + +TEST(MutableGraphViewTest, AddControllingFaninNotSwitch) { + // Actual node.op() is not important in this test. + GraphDef graph_def = test::function::GDef( + {NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})}, + /*funcs=*/{}); + + MutableGraphView graph(&graph_def); + EXPECT_TRUE(graph.AddControllingFanin("a", {"b", 2})); + EXPECT_FALSE(graph.AddControllingFanin("a", {"b", 2})); + + ASSERT_EQ(graph.graph()->node_size(), 2); + NodeDef* a = graph.GetNode("a"); + ASSERT_NE(a, nullptr); + ASSERT_EQ(a->input_size(), 1); + EXPECT_EQ(a->input(0), "^b"); + NodeDef* b = graph.GetNode("b"); + ASSERT_NE(b, nullptr); + ASSERT_EQ(b->input_size(), 0); +} + +TEST(MutableGraphViewTest, AddControllingFaninSwitchWithIdentity) { + GraphDef graph_def = test::function::GDef( + {NDef("a", "NotImportant", {}, {}), NDef("switch", "Switch", {}, {}), + NDef("identity", "Identity", {"switch"})}, + /*funcs=*/{}); + + MutableGraphView graph(&graph_def); + + EXPECT_TRUE(graph.AddControllingFanin("a", {"switch", 0})); + EXPECT_FALSE(graph.AddControllingFanin("a", {"switch", 0})); + + ASSERT_EQ(graph.graph()->node_size(), 3); + NodeDef* a = graph.GetNode("a"); + ASSERT_NE(a, nullptr); + ASSERT_EQ(a->input_size(), 1); + EXPECT_EQ(a->input(0), "^identity"); +} + +TEST(MutableGraphViewTest, AddControllingFaninSwitchWithNoExistingIdentity) { + constexpr char kDevice[] = "/device:foo:0"; + GraphDef graph_def = test::function::GDef( + {NDef("a", "NotImportant", {}, {}), + NDef("switch", "Switch", {}, {{"T", DT_FLOAT}}, kDevice)}, + /*funcs=*/{}); + + MutableGraphView graph(&graph_def); + + EXPECT_TRUE(graph.AddControllingFanin("a", {"switch", 0})); + EXPECT_FALSE(graph.AddControllingFanin("a", {"switch", 0})); + + ASSERT_EQ(graph.graph()->node_size(), 3); + NodeDef* a = graph.GetNode("a"); + ASSERT_NE(a, nullptr); + ASSERT_EQ(a->input_size(), 1); + EXPECT_EQ(a->input(0), "^ConstantFoldingCtrl/switch_0"); + NodeDef* identity = graph.GetNode("ConstantFoldingCtrl/switch_0"); + ASSERT_NE(identity, nullptr); + ASSERT_EQ(identity->input_size(), 1); + EXPECT_EQ(identity->input(0), "switch"); + EXPECT_EQ(identity->op(), "Identity"); + EXPECT_EQ(identity->device(), kDevice); + ASSERT_TRUE(identity->attr().count("T")); + EXPECT_EQ(identity->attr().at("T").type(), DT_FLOAT); +} + +TEST(MutableGraphViewTest, AddControllingFaninSwitchWithExistingAddedIdentity) { + GraphDef graph_def = test::function::GDef( + {NDef("a", "NotImportant", {}, {}), NDef("switch", "Switch", {}, {}), + NDef("ConstantFoldingCtrl/switch_0", "Identity", {}, {})}, + /*funcs=*/{}); + + MutableGraphView graph(&graph_def); + + EXPECT_TRUE(graph.AddControllingFanin("a", {"switch", 0})); + EXPECT_FALSE(graph.AddControllingFanin("a", {"switch", 0})); + + ASSERT_EQ(graph.graph()->node_size(), 3); + NodeDef* a = graph.GetNode("a"); + ASSERT_NE(a, nullptr); + ASSERT_EQ(a->input_size(), 1); + EXPECT_EQ(a->input(0), "^ConstantFoldingCtrl/switch_0"); +} + TEST(MutableGraphViewTest, DeleteNodes) { // Actual node.op() is not important in this test. GraphDef graph_def = test::function::GDef( @@ -469,7 +677,7 @@ TEST(MutableGraphViewTest, DeleteNodes) { NDef("other", "NotImportant", {}, {}), NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}), NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"})}, - /* empty function library */ {}); + /*funcs=*/{}); MutableGraphView graph(&graph_def);