[Grappler] Add helper methods for controlling fanin deduping and adding of controlling fanins for Switch ops in MutableGraphView.

PiperOrigin-RevId: 226194053
This commit is contained in:
Andy Ly 2018-12-19 10:26:41 -08:00 committed by TensorFlower Gardener
parent 264ce77f84
commit 6744f3c0fe
4 changed files with 381 additions and 21 deletions

View File

@ -196,6 +196,7 @@ tf_cc_test(
":utils",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -178,27 +179,21 @@ void MutableGraphView::UpdateFanouts(NodeDef* from_node, NodeDef* to_node) {
}
}
bool MutableGraphView::AddFanin(NodeDef* node, const TensorId& fanin) {
NodeDef* fanin_node = GetNode(fanin.node());
if (fanin_node == nullptr) {
return false;
}
bool MutableGraphView::AddFaninInternal(NodeDef* node,
const OutputPort& fanin) {
int num_non_controlling_fanins =
NumFanins(*node, /*include_controlling_nodes=*/false);
InputPort input;
input.node = node;
input.port_id = fanin.index() == Graph::kControlSlot
input.port_id = fanin.port_id == Graph::kControlSlot
? Graph::kControlSlot
: num_non_controlling_fanins;
OutputPort fanin_port(fanin_node, fanin.index());
if (!gtl::InsertIfNotPresent(&fanouts()[fanin_port], input)) {
if (!gtl::InsertIfNotPresent(&fanouts()[fanin], input)) {
return false;
}
node->add_input(TensorIdToString(fanin));
if (fanin.index() > Graph::kControlSlot) {
node->add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
if (fanin.port_id > Graph::kControlSlot) {
int node_input_size = node->input_size() - 1;
// If there are control dependencies in node, move newly inserted fanin to
// be before such control dependencies.
@ -210,6 +205,14 @@ bool MutableGraphView::AddFanin(NodeDef* node, const TensorId& fanin) {
return true;
}
bool MutableGraphView::AddFaninInternal(NodeDef* node, const TensorId& fanin) {
NodeDef* fanin_node = GetNode(fanin.node());
if (fanin_node == nullptr) {
return false;
}
return AddFaninInternal(node, {fanin_node, fanin.index()});
}
bool MutableGraphView::AddFanin(absl::string_view node_name,
const TensorId& fanin) {
if (!IsTensorIdPortValid(fanin)) {
@ -219,7 +222,7 @@ bool MutableGraphView::AddFanin(absl::string_view node_name,
if (node == nullptr) {
return false;
}
return AddFanin(node, fanin);
return AddFaninInternal(node, fanin);
}
bool MutableGraphView::RemoveFanins(NodeDef* node,
@ -318,7 +321,7 @@ bool MutableGraphView::UpdateFanin(absl::string_view node_name,
if (is_from_fanin_control || is_to_fanin_control) {
bool modified = RemoveFanins(node, {from_fanin});
if (!HasFanin(*node, to_fanin)) {
modified |= AddFanin(node, to_fanin);
modified |= AddFaninInternal(node, to_fanin);
}
return modified;
}
@ -357,6 +360,109 @@ bool MutableGraphView::UpdateFanin(absl::string_view node_name,
return modified;
}
bool MutableGraphView::DedupControllingFanins(NodeDef* node) {
absl::flat_hash_set<absl::string_view> fanins;
absl::flat_hash_set<string> removed_fanins;
int pos = 0;
const int last_idx = node->input_size() - 1;
int last_pos = last_idx;
while (pos <= last_pos) {
const string& input = node->input(pos);
TensorId tensor_id = ParseTensorName(input);
if (!gtl::InsertIfNotPresent(&fanins, tensor_id.node()) &&
IsControlInput(tensor_id)) {
node->mutable_input()->SwapElements(pos, last_pos--);
removed_fanins.insert(input);
} else {
++pos;
}
}
if (last_pos < last_idx) {
absl::flat_hash_set<string> retained_fanins(
node->input().begin(), node->input().begin() + last_pos + 1);
for (const auto& removed : removed_fanins) {
if (!retained_fanins.contains(removed)) {
OutputPort fanin(nodes()[ParseTensorName(removed).node()],
Graph::kControlSlot);
fanouts()[fanin].erase({node, Graph::kControlSlot});
}
}
node->mutable_input()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
return true;
}
return false;
}
bool MutableGraphView::DedupControllingFanins(absl::string_view node_name) {
NodeDef* node = GetNode(node_name);
if (node == nullptr) {
return false;
}
return DedupControllingFanins(node);
}
bool MutableGraphView::DedupControllingFanins() {
const int num_nodes = graph()->node_size();
bool modified = false;
for (int i = 0; i < num_nodes; ++i) {
modified |= DedupControllingFanins(graph()->mutable_node(i));
}
return modified;
}
bool MutableGraphView::AddControllingFanin(absl::string_view node_name,
const TensorId& fanin) {
NodeDef* node = GetNode(node_name);
if (node == nullptr) {
return false;
}
NodeDef* fanin_node = GetNode(fanin.node());
if (fanin_node == nullptr) {
return false;
}
if (fanin.index() == Graph::kControlSlot) {
return AddFaninInternal(node, {fanin_node, Graph::kControlSlot});
}
if (!IsSwitch(*fanin_node)) {
return AddFaninInternal(node, {fanin_node, Graph::kControlSlot});
} else {
// We can't anchor control dependencies directly on the switch node: unlike
// other nodes only one of the outputs of the switch node will be generated
// when the switch node is executed, and we need to make sure the control
// dependency is only triggered when the corresponding output is triggered.
// We start by looking for an identity node connected to the output of the
// switch node, and use it to anchor the control dependency.
auto fanouts = GetFanouts(*fanin_node, /*include_controlled_nodes=*/false);
for (auto fanout : fanouts) {
if (IsIdentity(*fanout.node) || IsIdentityNSingleInput(*fanout.node)) {
if (ParseTensorName(fanout.node->input(0)) == fanin) {
return AddFaninInternal(node, {fanout.node, Graph::kControlSlot});
}
}
}
// We haven't found an existing node where we can anchor the control
// dependency: add a new identity node.
string ctrl_dep_name = AddPrefixToNodeName(
absl::StrCat(fanin.node(), "_", fanin.index()), kMutableGraphViewCtrl);
NodeDef* ctrl_dep_node = GetNode(ctrl_dep_name);
if (ctrl_dep_node == nullptr) {
NodeDef new_node;
new_node.set_name(ctrl_dep_name);
new_node.set_op("Identity");
new_node.set_device(fanin_node->device());
(*new_node.mutable_attr())["T"].set_type(
fanin_node->attr().at("T").type());
new_node.add_input(TensorIdToString(fanin));
ctrl_dep_node = AddNode(std::move(new_node));
}
return AddFaninInternal(node, {ctrl_dep_node, Graph::kControlSlot});
}
}
void MutableGraphView::DeleteNodes(const std::set<string>& nodes_to_delete) {
for (const string& node_name_to_delete : nodes_to_delete)
RemoveFaninsInternal(nodes().at(node_name_to_delete),

View File

@ -31,6 +31,8 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
const char kMutableGraphViewCtrl[] = "ConstantFoldingCtrl";
// A utility class to simplify the traversal of a GraphDef that, unlike
// GraphView, supports updating the graph. Note that you should not modify the
// graph separately, because the view will get out of sync.
@ -102,6 +104,38 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
bool UpdateFanin(absl::string_view node_name, const TensorId& from_fanin,
const TensorId& to_fanin);
// Removes redundant control fanins from node `node_name`.
//
// This will return true iff the node is modified.
// TODO(lyandy): Measure performance of deduping on every AddFanin compared to
// deduping once at the end.
bool DedupControllingFanins(absl::string_view node_name);
// Removes redundant control fanins from all nodes in the graph.
//
// This will return true iff the node is modified.
bool DedupControllingFanins();
// Adds a control dependency to the target node named `node_name`.
//
// Case 1: If the fanin is not a Switch node, the control dependency is simply
// added to the target node:
//
// fanin -^> target node.
//
// Case 2: If the fanin is a Switch node, we cannot anchor a control
// dependency on it, because unlike other nodes, only one of its outputs will
// be generated when the node is activated. In this case, we try to find an
// Identity/IdentityN node in the fanout of the relevant port of the Switch
// and add it as a fanin to the target node. If no such Identity/IdentityN
// node can be found, a new Identity node will be created. In both cases, we
// end up with:
//
// fanin -> Identity{N} -^> target node.
//
// This will return true iff the node is modified.
bool AddControllingFanin(absl::string_view node_name, const TensorId& fanin);
// Deletes nodes from the graph.
void DeleteNodes(const std::set<string>& nodes_to_delete);
@ -121,11 +155,19 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
// behavior is undefined.
void UpdateFanouts(NodeDef* from_node, NodeDef* to_node);
// Remove fanins of the deleted node from internal state. Control dependencies
// are retained iff keep_controlling_fanins is true.
// Removes fanins of the deleted node from internal state. Control
// dependencies are retained iff keep_controlling_fanins is true.
void RemoveFaninsInternal(NodeDef* deleted_node,
bool keep_controlling_fanins);
// Add fanin to node. If fanin is a control dependency, existing control
// dependencies will be checked first before adding. Otherwise fanin will be
// added after existing non control dependency inputs.
//
// This will return true iff the node is modified. If a control dependency
// already exists, the node will not be modified.
bool AddFaninInternal(NodeDef* node, const OutputPort& fanin);
// Add fanin to node. If the node or fanin do not exist in the graph, nothing
// will be modified in the graph. If fanin is a control dependency, existing
// control dependencies will be checked first before adding. Otherwise fanin
@ -133,10 +175,13 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
//
// This will return true iff the node is modified. If a control dependency
// already exists, the node will not be modified.
bool AddFanin(NodeDef* node, const TensorId& fanin);
bool AddFaninInternal(NodeDef* node, const TensorId& fanin);
// Remove any fanin in node that matches to a fanin in fanins.
// Removes any fanin in node that matches to a fanin in fanins.
bool RemoveFanins(NodeDef* node, absl::Span<const TensorId> fanins);
// Removes redundant control fanins from node.
bool DedupControllingFanins(NodeDef* node);
};
} // end namespace grappler

View File

@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
@ -35,7 +36,7 @@ TEST(MutableGraphViewTest, AddAndUpdateFanouts) {
NDef("other", "NotImportant", {}, {}),
NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}),
NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"})},
/* empty function library */ {});
/*funcs=*/{});
MutableGraphView graph(&graph_def);
@ -78,7 +79,7 @@ TEST(MutableGraphViewTest, AddAndUpdateFanoutsWithoutSelfLoops) {
GraphDef graph_def =
test::function::GDef({NDef("bar", "NotImportant", {}, {}),
NDef("foo", "NotImportant", {"bar", "^bar"})},
/* empty function library */ {});
/*funcs=*/{});
MutableGraphView graph(&graph_def);
@ -462,6 +463,213 @@ TEST(MutableGraphViewTest, UpdateFanin) {
/*modified=*/false, /*expected_node=*/nullptr);
}
GraphDef SimpleDuplicateControllingFaninsGraph() {
// Actual node.op() is not important in this test.
GraphDef graph_def = test::function::GDef(
{NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
NDef("foo_1", "NotImportant", {"a", "b:1", "^b"}),
NDef("foo_2", "NotImportant", {"a", "^b", "^b"}),
NDef("foo_3", "NotImportant", {"a", "b:1", "^b", "^b"}),
NDef("foo_4", "NotImportant", {"a:2", "b:1", "^b", "^b", "^a", "^a"})},
/*funcs=*/{});
return graph_def;
}
void CheckDedupControllingFaninsForNode(MutableGraphView* graph,
absl::string_view node_name,
const NodeDef* expected_node) {
// Deduping again should result in no change.
EXPECT_FALSE(graph->DedupControllingFanins(node_name));
NodeDef* node = graph->GetNode(node_name);
ASSERT_NE(node, nullptr);
ASSERT_EQ(node->input_size(), expected_node->input_size());
CompareNodeInputs(*graph, expected_node, node);
for (int i = 0; i < node->input_size(); ++i) {
TensorId tensor_id = ParseTensorName(node->input(i));
if (tensor_id.index() > Graph::kControlSlot) {
CheckFanout(*graph, {tensor_id.node(), Graph::kControlSlot}, node_name);
}
}
}
void TestDedupControllingFaninsForNode(MutableGraphView* graph,
absl::string_view node_name,
const NodeDef* expected_node) {
EXPECT_TRUE(graph->DedupControllingFanins(node_name));
CheckDedupControllingFaninsForNode(graph, node_name, expected_node);
}
TEST(MutableGraphViewTest, DedupControllingFaninsForNode) {
GraphDef graph_def = SimpleDuplicateControllingFaninsGraph();
MutableGraphView graph(&graph_def);
NodeDef expected_node;
// Remove redundant control dependency '^b'.
expected_node = NDef("", "", {"a", "b:1"});
TestDedupControllingFaninsForNode(&graph, "foo_1", &expected_node);
// Remove extra control dependency '^b'.
expected_node = NDef("", "", {"a", "^b"});
TestDedupControllingFaninsForNode(&graph, "foo_2", &expected_node);
// Remove redundant and extra control dependencies '^b'.
expected_node = NDef("", "", {"a", "b:1"});
TestDedupControllingFaninsForNode(&graph, "foo_3", &expected_node);
// Remove multiple redundant control dependencies.
expected_node = NDef("", "", {"a:2", "b:1"});
TestDedupControllingFaninsForNode(&graph, "foo_4", &expected_node);
// Missing node.
EXPECT_FALSE(graph.DedupControllingFanins("missing"));
}
TEST(MutableGraphViewTest, DedupControllingFaninsForGraph) {
GraphDef graph_def = SimpleDuplicateControllingFaninsGraph();
MutableGraphView graph(&graph_def);
EXPECT_TRUE(graph.DedupControllingFanins());
// Deduping again should result in no change.
EXPECT_FALSE(graph.DedupControllingFanins());
NodeDef expected_node;
// Remove redundant control dependency '^b'.
expected_node = NDef("", "", {"a", "b:1"});
CheckDedupControllingFaninsForNode(&graph, "foo_1", &expected_node);
// Remove extra control dependency '^b'.
expected_node = NDef("", "", {"a", "^b"});
CheckDedupControllingFaninsForNode(&graph, "foo_2", &expected_node);
// Remove redundant and extra control dependencies '^b'.
expected_node = NDef("", "", {"a", "b:1"});
CheckDedupControllingFaninsForNode(&graph, "foo_3", &expected_node);
// Remove multiple redundant control dependencies.
expected_node = NDef("", "", {"a:2", "b:1"});
CheckDedupControllingFaninsForNode(&graph, "foo_4", &expected_node);
}
TEST(MutableGraphViewTest, AddControllingFaninMissing) {
// Actual node.op() is not important in this test.
GraphDef graph_def = test::function::GDef(
{NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})},
/*funcs=*/{});
MutableGraphView graph(&graph_def);
// Missing fanin.
EXPECT_FALSE(graph.AddControllingFanin("a", {"c", Graph::kControlSlot}));
// Missing node.
EXPECT_FALSE(graph.AddControllingFanin("d", {"a", Graph::kControlSlot}));
// Missing node and fanin.
EXPECT_FALSE(graph.AddControllingFanin("c", {"d", Graph::kControlSlot}));
ASSERT_EQ(graph.graph()->node_size(), 2);
NodeDef* a = graph.GetNode("a");
ASSERT_NE(a, nullptr);
ASSERT_EQ(a->input_size(), 0);
NodeDef* b = graph.GetNode("b");
ASSERT_NE(b, nullptr);
ASSERT_EQ(b->input_size(), 0);
}
TEST(MutableGraphViewTest, AddControllingFaninExistingControl) {
// Actual node.op() is not important in this test.
GraphDef graph_def = test::function::GDef(
{NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})},
/*funcs=*/{});
MutableGraphView graph(&graph_def);
EXPECT_TRUE(graph.AddControllingFanin("a", {"b", Graph::kControlSlot}));
EXPECT_FALSE(graph.AddControllingFanin("a", {"b", Graph::kControlSlot}));
ASSERT_EQ(graph.graph()->node_size(), 2);
NodeDef* a = graph.GetNode("a");
ASSERT_NE(a, nullptr);
ASSERT_EQ(a->input_size(), 1);
EXPECT_EQ(a->input(0), "^b");
NodeDef* b = graph.GetNode("b");
ASSERT_NE(b, nullptr);
ASSERT_EQ(b->input_size(), 0);
}
TEST(MutableGraphViewTest, AddControllingFaninNotSwitch) {
// Actual node.op() is not important in this test.
GraphDef graph_def = test::function::GDef(
{NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})},
/*funcs=*/{});
MutableGraphView graph(&graph_def);
EXPECT_TRUE(graph.AddControllingFanin("a", {"b", 2}));
EXPECT_FALSE(graph.AddControllingFanin("a", {"b", 2}));
ASSERT_EQ(graph.graph()->node_size(), 2);
NodeDef* a = graph.GetNode("a");
ASSERT_NE(a, nullptr);
ASSERT_EQ(a->input_size(), 1);
EXPECT_EQ(a->input(0), "^b");
NodeDef* b = graph.GetNode("b");
ASSERT_NE(b, nullptr);
ASSERT_EQ(b->input_size(), 0);
}
TEST(MutableGraphViewTest, AddControllingFaninSwitchWithIdentity) {
GraphDef graph_def = test::function::GDef(
{NDef("a", "NotImportant", {}, {}), NDef("switch", "Switch", {}, {}),
NDef("identity", "Identity", {"switch"})},
/*funcs=*/{});
MutableGraphView graph(&graph_def);
EXPECT_TRUE(graph.AddControllingFanin("a", {"switch", 0}));
EXPECT_FALSE(graph.AddControllingFanin("a", {"switch", 0}));
ASSERT_EQ(graph.graph()->node_size(), 3);
NodeDef* a = graph.GetNode("a");
ASSERT_NE(a, nullptr);
ASSERT_EQ(a->input_size(), 1);
EXPECT_EQ(a->input(0), "^identity");
}
TEST(MutableGraphViewTest, AddControllingFaninSwitchWithNoExistingIdentity) {
constexpr char kDevice[] = "/device:foo:0";
GraphDef graph_def = test::function::GDef(
{NDef("a", "NotImportant", {}, {}),
NDef("switch", "Switch", {}, {{"T", DT_FLOAT}}, kDevice)},
/*funcs=*/{});
MutableGraphView graph(&graph_def);
EXPECT_TRUE(graph.AddControllingFanin("a", {"switch", 0}));
EXPECT_FALSE(graph.AddControllingFanin("a", {"switch", 0}));
ASSERT_EQ(graph.graph()->node_size(), 3);
NodeDef* a = graph.GetNode("a");
ASSERT_NE(a, nullptr);
ASSERT_EQ(a->input_size(), 1);
EXPECT_EQ(a->input(0), "^ConstantFoldingCtrl/switch_0");
NodeDef* identity = graph.GetNode("ConstantFoldingCtrl/switch_0");
ASSERT_NE(identity, nullptr);
ASSERT_EQ(identity->input_size(), 1);
EXPECT_EQ(identity->input(0), "switch");
EXPECT_EQ(identity->op(), "Identity");
EXPECT_EQ(identity->device(), kDevice);
ASSERT_TRUE(identity->attr().count("T"));
EXPECT_EQ(identity->attr().at("T").type(), DT_FLOAT);
}
TEST(MutableGraphViewTest, AddControllingFaninSwitchWithExistingAddedIdentity) {
GraphDef graph_def = test::function::GDef(
{NDef("a", "NotImportant", {}, {}), NDef("switch", "Switch", {}, {}),
NDef("ConstantFoldingCtrl/switch_0", "Identity", {}, {})},
/*funcs=*/{});
MutableGraphView graph(&graph_def);
EXPECT_TRUE(graph.AddControllingFanin("a", {"switch", 0}));
EXPECT_FALSE(graph.AddControllingFanin("a", {"switch", 0}));
ASSERT_EQ(graph.graph()->node_size(), 3);
NodeDef* a = graph.GetNode("a");
ASSERT_NE(a, nullptr);
ASSERT_EQ(a->input_size(), 1);
EXPECT_EQ(a->input(0), "^ConstantFoldingCtrl/switch_0");
}
TEST(MutableGraphViewTest, DeleteNodes) {
// Actual node.op() is not important in this test.
GraphDef graph_def = test::function::GDef(
@ -469,7 +677,7 @@ TEST(MutableGraphViewTest, DeleteNodes) {
NDef("other", "NotImportant", {}, {}),
NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}),
NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"})},
/* empty function library */ {});
/*funcs=*/{});
MutableGraphView graph(&graph_def);