[Grappler] Add helper methods for controlling fanin deduping and adding of controlling fanins for Switch ops in MutableGraphView.
PiperOrigin-RevId: 226194053
This commit is contained in:
parent
264ce77f84
commit
6744f3c0fe
@ -196,6 +196,7 @@ tf_cc_test(
|
||||
":utils",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -178,27 +179,21 @@ void MutableGraphView::UpdateFanouts(NodeDef* from_node, NodeDef* to_node) {
|
||||
}
|
||||
}
|
||||
|
||||
bool MutableGraphView::AddFanin(NodeDef* node, const TensorId& fanin) {
|
||||
NodeDef* fanin_node = GetNode(fanin.node());
|
||||
if (fanin_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool MutableGraphView::AddFaninInternal(NodeDef* node,
|
||||
const OutputPort& fanin) {
|
||||
int num_non_controlling_fanins =
|
||||
NumFanins(*node, /*include_controlling_nodes=*/false);
|
||||
InputPort input;
|
||||
input.node = node;
|
||||
input.port_id = fanin.index() == Graph::kControlSlot
|
||||
input.port_id = fanin.port_id == Graph::kControlSlot
|
||||
? Graph::kControlSlot
|
||||
: num_non_controlling_fanins;
|
||||
|
||||
OutputPort fanin_port(fanin_node, fanin.index());
|
||||
|
||||
if (!gtl::InsertIfNotPresent(&fanouts()[fanin_port], input)) {
|
||||
if (!gtl::InsertIfNotPresent(&fanouts()[fanin], input)) {
|
||||
return false;
|
||||
}
|
||||
node->add_input(TensorIdToString(fanin));
|
||||
if (fanin.index() > Graph::kControlSlot) {
|
||||
node->add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
|
||||
if (fanin.port_id > Graph::kControlSlot) {
|
||||
int node_input_size = node->input_size() - 1;
|
||||
// If there are control dependencies in node, move newly inserted fanin to
|
||||
// be before such control dependencies.
|
||||
@ -210,6 +205,14 @@ bool MutableGraphView::AddFanin(NodeDef* node, const TensorId& fanin) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MutableGraphView::AddFaninInternal(NodeDef* node, const TensorId& fanin) {
|
||||
NodeDef* fanin_node = GetNode(fanin.node());
|
||||
if (fanin_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return AddFaninInternal(node, {fanin_node, fanin.index()});
|
||||
}
|
||||
|
||||
bool MutableGraphView::AddFanin(absl::string_view node_name,
|
||||
const TensorId& fanin) {
|
||||
if (!IsTensorIdPortValid(fanin)) {
|
||||
@ -219,7 +222,7 @@ bool MutableGraphView::AddFanin(absl::string_view node_name,
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return AddFanin(node, fanin);
|
||||
return AddFaninInternal(node, fanin);
|
||||
}
|
||||
|
||||
bool MutableGraphView::RemoveFanins(NodeDef* node,
|
||||
@ -318,7 +321,7 @@ bool MutableGraphView::UpdateFanin(absl::string_view node_name,
|
||||
if (is_from_fanin_control || is_to_fanin_control) {
|
||||
bool modified = RemoveFanins(node, {from_fanin});
|
||||
if (!HasFanin(*node, to_fanin)) {
|
||||
modified |= AddFanin(node, to_fanin);
|
||||
modified |= AddFaninInternal(node, to_fanin);
|
||||
}
|
||||
return modified;
|
||||
}
|
||||
@ -357,6 +360,109 @@ bool MutableGraphView::UpdateFanin(absl::string_view node_name,
|
||||
return modified;
|
||||
}
|
||||
|
||||
bool MutableGraphView::DedupControllingFanins(NodeDef* node) {
|
||||
absl::flat_hash_set<absl::string_view> fanins;
|
||||
absl::flat_hash_set<string> removed_fanins;
|
||||
int pos = 0;
|
||||
const int last_idx = node->input_size() - 1;
|
||||
int last_pos = last_idx;
|
||||
while (pos <= last_pos) {
|
||||
const string& input = node->input(pos);
|
||||
TensorId tensor_id = ParseTensorName(input);
|
||||
if (!gtl::InsertIfNotPresent(&fanins, tensor_id.node()) &&
|
||||
IsControlInput(tensor_id)) {
|
||||
node->mutable_input()->SwapElements(pos, last_pos--);
|
||||
removed_fanins.insert(input);
|
||||
} else {
|
||||
++pos;
|
||||
}
|
||||
}
|
||||
|
||||
if (last_pos < last_idx) {
|
||||
absl::flat_hash_set<string> retained_fanins(
|
||||
node->input().begin(), node->input().begin() + last_pos + 1);
|
||||
for (const auto& removed : removed_fanins) {
|
||||
if (!retained_fanins.contains(removed)) {
|
||||
OutputPort fanin(nodes()[ParseTensorName(removed).node()],
|
||||
Graph::kControlSlot);
|
||||
fanouts()[fanin].erase({node, Graph::kControlSlot});
|
||||
}
|
||||
}
|
||||
node->mutable_input()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool MutableGraphView::DedupControllingFanins(absl::string_view node_name) {
|
||||
NodeDef* node = GetNode(node_name);
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return DedupControllingFanins(node);
|
||||
}
|
||||
|
||||
bool MutableGraphView::DedupControllingFanins() {
|
||||
const int num_nodes = graph()->node_size();
|
||||
bool modified = false;
|
||||
for (int i = 0; i < num_nodes; ++i) {
|
||||
modified |= DedupControllingFanins(graph()->mutable_node(i));
|
||||
}
|
||||
return modified;
|
||||
}
|
||||
|
||||
bool MutableGraphView::AddControllingFanin(absl::string_view node_name,
|
||||
const TensorId& fanin) {
|
||||
NodeDef* node = GetNode(node_name);
|
||||
if (node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
NodeDef* fanin_node = GetNode(fanin.node());
|
||||
if (fanin_node == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (fanin.index() == Graph::kControlSlot) {
|
||||
return AddFaninInternal(node, {fanin_node, Graph::kControlSlot});
|
||||
}
|
||||
|
||||
if (!IsSwitch(*fanin_node)) {
|
||||
return AddFaninInternal(node, {fanin_node, Graph::kControlSlot});
|
||||
} else {
|
||||
// We can't anchor control dependencies directly on the switch node: unlike
|
||||
// other nodes only one of the outputs of the switch node will be generated
|
||||
// when the switch node is executed, and we need to make sure the control
|
||||
// dependency is only triggered when the corresponding output is triggered.
|
||||
// We start by looking for an identity node connected to the output of the
|
||||
// switch node, and use it to anchor the control dependency.
|
||||
auto fanouts = GetFanouts(*fanin_node, /*include_controlled_nodes=*/false);
|
||||
for (auto fanout : fanouts) {
|
||||
if (IsIdentity(*fanout.node) || IsIdentityNSingleInput(*fanout.node)) {
|
||||
if (ParseTensorName(fanout.node->input(0)) == fanin) {
|
||||
return AddFaninInternal(node, {fanout.node, Graph::kControlSlot});
|
||||
}
|
||||
}
|
||||
}
|
||||
// We haven't found an existing node where we can anchor the control
|
||||
// dependency: add a new identity node.
|
||||
string ctrl_dep_name = AddPrefixToNodeName(
|
||||
absl::StrCat(fanin.node(), "_", fanin.index()), kMutableGraphViewCtrl);
|
||||
|
||||
NodeDef* ctrl_dep_node = GetNode(ctrl_dep_name);
|
||||
if (ctrl_dep_node == nullptr) {
|
||||
NodeDef new_node;
|
||||
new_node.set_name(ctrl_dep_name);
|
||||
new_node.set_op("Identity");
|
||||
new_node.set_device(fanin_node->device());
|
||||
(*new_node.mutable_attr())["T"].set_type(
|
||||
fanin_node->attr().at("T").type());
|
||||
new_node.add_input(TensorIdToString(fanin));
|
||||
ctrl_dep_node = AddNode(std::move(new_node));
|
||||
}
|
||||
return AddFaninInternal(node, {ctrl_dep_node, Graph::kControlSlot});
|
||||
}
|
||||
}
|
||||
|
||||
void MutableGraphView::DeleteNodes(const std::set<string>& nodes_to_delete) {
|
||||
for (const string& node_name_to_delete : nodes_to_delete)
|
||||
RemoveFaninsInternal(nodes().at(node_name_to_delete),
|
||||
|
@ -31,6 +31,8 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
const char kMutableGraphViewCtrl[] = "ConstantFoldingCtrl";
|
||||
|
||||
// A utility class to simplify the traversal of a GraphDef that, unlike
|
||||
// GraphView, supports updating the graph. Note that you should not modify the
|
||||
// graph separately, because the view will get out of sync.
|
||||
@ -102,6 +104,38 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
|
||||
bool UpdateFanin(absl::string_view node_name, const TensorId& from_fanin,
|
||||
const TensorId& to_fanin);
|
||||
|
||||
// Removes redundant control fanins from node `node_name`.
|
||||
//
|
||||
// This will return true iff the node is modified.
|
||||
// TODO(lyandy): Measure performance of deduping on every AddFanin compared to
|
||||
// deduping once at the end.
|
||||
bool DedupControllingFanins(absl::string_view node_name);
|
||||
|
||||
// Removes redundant control fanins from all nodes in the graph.
|
||||
//
|
||||
// This will return true iff the node is modified.
|
||||
bool DedupControllingFanins();
|
||||
|
||||
// Adds a control dependency to the target node named `node_name`.
|
||||
//
|
||||
// Case 1: If the fanin is not a Switch node, the control dependency is simply
|
||||
// added to the target node:
|
||||
//
|
||||
// fanin -^> target node.
|
||||
//
|
||||
// Case 2: If the fanin is a Switch node, we cannot anchor a control
|
||||
// dependency on it, because unlike other nodes, only one of its outputs will
|
||||
// be generated when the node is activated. In this case, we try to find an
|
||||
// Identity/IdentityN node in the fanout of the relevant port of the Switch
|
||||
// and add it as a fanin to the target node. If no such Identity/IdentityN
|
||||
// node can be found, a new Identity node will be created. In both cases, we
|
||||
// end up with:
|
||||
//
|
||||
// fanin -> Identity{N} -^> target node.
|
||||
//
|
||||
// This will return true iff the node is modified.
|
||||
bool AddControllingFanin(absl::string_view node_name, const TensorId& fanin);
|
||||
|
||||
// Deletes nodes from the graph.
|
||||
void DeleteNodes(const std::set<string>& nodes_to_delete);
|
||||
|
||||
@ -121,11 +155,19 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
|
||||
// behavior is undefined.
|
||||
void UpdateFanouts(NodeDef* from_node, NodeDef* to_node);
|
||||
|
||||
// Remove fanins of the deleted node from internal state. Control dependencies
|
||||
// are retained iff keep_controlling_fanins is true.
|
||||
// Removes fanins of the deleted node from internal state. Control
|
||||
// dependencies are retained iff keep_controlling_fanins is true.
|
||||
void RemoveFaninsInternal(NodeDef* deleted_node,
|
||||
bool keep_controlling_fanins);
|
||||
|
||||
// Add fanin to node. If fanin is a control dependency, existing control
|
||||
// dependencies will be checked first before adding. Otherwise fanin will be
|
||||
// added after existing non control dependency inputs.
|
||||
//
|
||||
// This will return true iff the node is modified. If a control dependency
|
||||
// already exists, the node will not be modified.
|
||||
bool AddFaninInternal(NodeDef* node, const OutputPort& fanin);
|
||||
|
||||
// Add fanin to node. If the node or fanin do not exist in the graph, nothing
|
||||
// will be modified in the graph. If fanin is a control dependency, existing
|
||||
// control dependencies will be checked first before adding. Otherwise fanin
|
||||
@ -133,10 +175,13 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
|
||||
//
|
||||
// This will return true iff the node is modified. If a control dependency
|
||||
// already exists, the node will not be modified.
|
||||
bool AddFanin(NodeDef* node, const TensorId& fanin);
|
||||
bool AddFaninInternal(NodeDef* node, const TensorId& fanin);
|
||||
|
||||
// Remove any fanin in node that matches to a fanin in fanins.
|
||||
// Removes any fanin in node that matches to a fanin in fanins.
|
||||
bool RemoveFanins(NodeDef* node, absl::Span<const TensorId> fanins);
|
||||
|
||||
// Removes redundant control fanins from node.
|
||||
bool DedupControllingFanins(NodeDef* node);
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
@ -35,7 +36,7 @@ TEST(MutableGraphViewTest, AddAndUpdateFanouts) {
|
||||
NDef("other", "NotImportant", {}, {}),
|
||||
NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}),
|
||||
NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"})},
|
||||
/* empty function library */ {});
|
||||
/*funcs=*/{});
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
@ -78,7 +79,7 @@ TEST(MutableGraphViewTest, AddAndUpdateFanoutsWithoutSelfLoops) {
|
||||
GraphDef graph_def =
|
||||
test::function::GDef({NDef("bar", "NotImportant", {}, {}),
|
||||
NDef("foo", "NotImportant", {"bar", "^bar"})},
|
||||
/* empty function library */ {});
|
||||
/*funcs=*/{});
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
@ -462,6 +463,213 @@ TEST(MutableGraphViewTest, UpdateFanin) {
|
||||
/*modified=*/false, /*expected_node=*/nullptr);
|
||||
}
|
||||
|
||||
GraphDef SimpleDuplicateControllingFaninsGraph() {
|
||||
// Actual node.op() is not important in this test.
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
{NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
|
||||
NDef("foo_1", "NotImportant", {"a", "b:1", "^b"}),
|
||||
NDef("foo_2", "NotImportant", {"a", "^b", "^b"}),
|
||||
NDef("foo_3", "NotImportant", {"a", "b:1", "^b", "^b"}),
|
||||
NDef("foo_4", "NotImportant", {"a:2", "b:1", "^b", "^b", "^a", "^a"})},
|
||||
/*funcs=*/{});
|
||||
return graph_def;
|
||||
}
|
||||
|
||||
void CheckDedupControllingFaninsForNode(MutableGraphView* graph,
|
||||
absl::string_view node_name,
|
||||
const NodeDef* expected_node) {
|
||||
// Deduping again should result in no change.
|
||||
EXPECT_FALSE(graph->DedupControllingFanins(node_name));
|
||||
NodeDef* node = graph->GetNode(node_name);
|
||||
ASSERT_NE(node, nullptr);
|
||||
ASSERT_EQ(node->input_size(), expected_node->input_size());
|
||||
CompareNodeInputs(*graph, expected_node, node);
|
||||
for (int i = 0; i < node->input_size(); ++i) {
|
||||
TensorId tensor_id = ParseTensorName(node->input(i));
|
||||
if (tensor_id.index() > Graph::kControlSlot) {
|
||||
CheckFanout(*graph, {tensor_id.node(), Graph::kControlSlot}, node_name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TestDedupControllingFaninsForNode(MutableGraphView* graph,
|
||||
absl::string_view node_name,
|
||||
const NodeDef* expected_node) {
|
||||
EXPECT_TRUE(graph->DedupControllingFanins(node_name));
|
||||
CheckDedupControllingFaninsForNode(graph, node_name, expected_node);
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, DedupControllingFaninsForNode) {
|
||||
GraphDef graph_def = SimpleDuplicateControllingFaninsGraph();
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
NodeDef expected_node;
|
||||
// Remove redundant control dependency '^b'.
|
||||
expected_node = NDef("", "", {"a", "b:1"});
|
||||
TestDedupControllingFaninsForNode(&graph, "foo_1", &expected_node);
|
||||
// Remove extra control dependency '^b'.
|
||||
expected_node = NDef("", "", {"a", "^b"});
|
||||
TestDedupControllingFaninsForNode(&graph, "foo_2", &expected_node);
|
||||
// Remove redundant and extra control dependencies '^b'.
|
||||
expected_node = NDef("", "", {"a", "b:1"});
|
||||
TestDedupControllingFaninsForNode(&graph, "foo_3", &expected_node);
|
||||
// Remove multiple redundant control dependencies.
|
||||
expected_node = NDef("", "", {"a:2", "b:1"});
|
||||
TestDedupControllingFaninsForNode(&graph, "foo_4", &expected_node);
|
||||
// Missing node.
|
||||
EXPECT_FALSE(graph.DedupControllingFanins("missing"));
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, DedupControllingFaninsForGraph) {
|
||||
GraphDef graph_def = SimpleDuplicateControllingFaninsGraph();
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
EXPECT_TRUE(graph.DedupControllingFanins());
|
||||
// Deduping again should result in no change.
|
||||
EXPECT_FALSE(graph.DedupControllingFanins());
|
||||
|
||||
NodeDef expected_node;
|
||||
// Remove redundant control dependency '^b'.
|
||||
expected_node = NDef("", "", {"a", "b:1"});
|
||||
CheckDedupControllingFaninsForNode(&graph, "foo_1", &expected_node);
|
||||
// Remove extra control dependency '^b'.
|
||||
expected_node = NDef("", "", {"a", "^b"});
|
||||
CheckDedupControllingFaninsForNode(&graph, "foo_2", &expected_node);
|
||||
// Remove redundant and extra control dependencies '^b'.
|
||||
expected_node = NDef("", "", {"a", "b:1"});
|
||||
CheckDedupControllingFaninsForNode(&graph, "foo_3", &expected_node);
|
||||
// Remove multiple redundant control dependencies.
|
||||
expected_node = NDef("", "", {"a:2", "b:1"});
|
||||
CheckDedupControllingFaninsForNode(&graph, "foo_4", &expected_node);
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, AddControllingFaninMissing) {
|
||||
// Actual node.op() is not important in this test.
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
{NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})},
|
||||
/*funcs=*/{});
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
// Missing fanin.
|
||||
EXPECT_FALSE(graph.AddControllingFanin("a", {"c", Graph::kControlSlot}));
|
||||
// Missing node.
|
||||
EXPECT_FALSE(graph.AddControllingFanin("d", {"a", Graph::kControlSlot}));
|
||||
// Missing node and fanin.
|
||||
EXPECT_FALSE(graph.AddControllingFanin("c", {"d", Graph::kControlSlot}));
|
||||
|
||||
ASSERT_EQ(graph.graph()->node_size(), 2);
|
||||
NodeDef* a = graph.GetNode("a");
|
||||
ASSERT_NE(a, nullptr);
|
||||
ASSERT_EQ(a->input_size(), 0);
|
||||
NodeDef* b = graph.GetNode("b");
|
||||
ASSERT_NE(b, nullptr);
|
||||
ASSERT_EQ(b->input_size(), 0);
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, AddControllingFaninExistingControl) {
|
||||
// Actual node.op() is not important in this test.
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
{NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})},
|
||||
/*funcs=*/{});
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
EXPECT_TRUE(graph.AddControllingFanin("a", {"b", Graph::kControlSlot}));
|
||||
EXPECT_FALSE(graph.AddControllingFanin("a", {"b", Graph::kControlSlot}));
|
||||
|
||||
ASSERT_EQ(graph.graph()->node_size(), 2);
|
||||
NodeDef* a = graph.GetNode("a");
|
||||
ASSERT_NE(a, nullptr);
|
||||
ASSERT_EQ(a->input_size(), 1);
|
||||
EXPECT_EQ(a->input(0), "^b");
|
||||
NodeDef* b = graph.GetNode("b");
|
||||
ASSERT_NE(b, nullptr);
|
||||
ASSERT_EQ(b->input_size(), 0);
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, AddControllingFaninNotSwitch) {
|
||||
// Actual node.op() is not important in this test.
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
{NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {})},
|
||||
/*funcs=*/{});
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
EXPECT_TRUE(graph.AddControllingFanin("a", {"b", 2}));
|
||||
EXPECT_FALSE(graph.AddControllingFanin("a", {"b", 2}));
|
||||
|
||||
ASSERT_EQ(graph.graph()->node_size(), 2);
|
||||
NodeDef* a = graph.GetNode("a");
|
||||
ASSERT_NE(a, nullptr);
|
||||
ASSERT_EQ(a->input_size(), 1);
|
||||
EXPECT_EQ(a->input(0), "^b");
|
||||
NodeDef* b = graph.GetNode("b");
|
||||
ASSERT_NE(b, nullptr);
|
||||
ASSERT_EQ(b->input_size(), 0);
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, AddControllingFaninSwitchWithIdentity) {
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
{NDef("a", "NotImportant", {}, {}), NDef("switch", "Switch", {}, {}),
|
||||
NDef("identity", "Identity", {"switch"})},
|
||||
/*funcs=*/{});
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
EXPECT_TRUE(graph.AddControllingFanin("a", {"switch", 0}));
|
||||
EXPECT_FALSE(graph.AddControllingFanin("a", {"switch", 0}));
|
||||
|
||||
ASSERT_EQ(graph.graph()->node_size(), 3);
|
||||
NodeDef* a = graph.GetNode("a");
|
||||
ASSERT_NE(a, nullptr);
|
||||
ASSERT_EQ(a->input_size(), 1);
|
||||
EXPECT_EQ(a->input(0), "^identity");
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, AddControllingFaninSwitchWithNoExistingIdentity) {
|
||||
constexpr char kDevice[] = "/device:foo:0";
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
{NDef("a", "NotImportant", {}, {}),
|
||||
NDef("switch", "Switch", {}, {{"T", DT_FLOAT}}, kDevice)},
|
||||
/*funcs=*/{});
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
EXPECT_TRUE(graph.AddControllingFanin("a", {"switch", 0}));
|
||||
EXPECT_FALSE(graph.AddControllingFanin("a", {"switch", 0}));
|
||||
|
||||
ASSERT_EQ(graph.graph()->node_size(), 3);
|
||||
NodeDef* a = graph.GetNode("a");
|
||||
ASSERT_NE(a, nullptr);
|
||||
ASSERT_EQ(a->input_size(), 1);
|
||||
EXPECT_EQ(a->input(0), "^ConstantFoldingCtrl/switch_0");
|
||||
NodeDef* identity = graph.GetNode("ConstantFoldingCtrl/switch_0");
|
||||
ASSERT_NE(identity, nullptr);
|
||||
ASSERT_EQ(identity->input_size(), 1);
|
||||
EXPECT_EQ(identity->input(0), "switch");
|
||||
EXPECT_EQ(identity->op(), "Identity");
|
||||
EXPECT_EQ(identity->device(), kDevice);
|
||||
ASSERT_TRUE(identity->attr().count("T"));
|
||||
EXPECT_EQ(identity->attr().at("T").type(), DT_FLOAT);
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, AddControllingFaninSwitchWithExistingAddedIdentity) {
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
{NDef("a", "NotImportant", {}, {}), NDef("switch", "Switch", {}, {}),
|
||||
NDef("ConstantFoldingCtrl/switch_0", "Identity", {}, {})},
|
||||
/*funcs=*/{});
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
EXPECT_TRUE(graph.AddControllingFanin("a", {"switch", 0}));
|
||||
EXPECT_FALSE(graph.AddControllingFanin("a", {"switch", 0}));
|
||||
|
||||
ASSERT_EQ(graph.graph()->node_size(), 3);
|
||||
NodeDef* a = graph.GetNode("a");
|
||||
ASSERT_NE(a, nullptr);
|
||||
ASSERT_EQ(a->input_size(), 1);
|
||||
EXPECT_EQ(a->input(0), "^ConstantFoldingCtrl/switch_0");
|
||||
}
|
||||
|
||||
TEST(MutableGraphViewTest, DeleteNodes) {
|
||||
// Actual node.op() is not important in this test.
|
||||
GraphDef graph_def = test::function::GDef(
|
||||
@ -469,7 +677,7 @@ TEST(MutableGraphViewTest, DeleteNodes) {
|
||||
NDef("other", "NotImportant", {}, {}),
|
||||
NDef("foo_1", "NotImportant", {"bar", "other", "bar:1", "^bar"}),
|
||||
NDef("foo_2", "NotImportant", {"other:1", "bar:2", "^bar"})},
|
||||
/* empty function library */ {});
|
||||
/*funcs=*/{});
|
||||
|
||||
MutableGraphView graph(&graph_def);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user