[Grappler] Add node fanin mutations in MutableGraphView.

PiperOrigin-RevId: 225474536
This commit is contained in:
Andy Ly 2018-12-13 19:03:48 -08:00 committed by TensorFlower Gardener
parent 3605ae973e
commit 0b0dea8cf1
4 changed files with 619 additions and 6 deletions

View File

@ -176,12 +176,14 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":graph_view",
":grappler_item",
":op_types",
":utils",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
@ -191,7 +193,9 @@ tf_cc_test(
deps = [
":grappler_item",
":mutable_graph_view",
":utils",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:graph",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",

View File

@ -14,14 +14,32 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include <algorithm>
#include <utility>
#include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
namespace {
bool IsTensorIdPortValid(const TensorId& tensor_id) {
return tensor_id.index() >= Graph::kControlSlot;
}
} // namespace
const absl::flat_hash_set<MutableGraphView::InputPort>&
MutableGraphView::GetFanout(const GraphView::OutputPort& port) const {
return GetFanout(MutableGraphView::OutputPort(const_cast<NodeDef*>(port.node),
@ -160,17 +178,201 @@ void MutableGraphView::UpdateFanouts(NodeDef* from_node, NodeDef* to_node) {
}
}
bool MutableGraphView::AddFanin(NodeDef* node, const TensorId& fanin) {
NodeDef* fanin_node = GetNode(fanin.node());
if (fanin_node == nullptr) {
return false;
}
int num_non_controlling_fanins =
NumFanins(*node, /*include_controlling_nodes=*/false);
InputPort input;
input.node = node;
input.port_id = fanin.index() == Graph::kControlSlot
? Graph::kControlSlot
: num_non_controlling_fanins;
OutputPort fanin_port(fanin_node, fanin.index());
if (!gtl::InsertIfNotPresent(&fanouts()[fanin_port], input)) {
return false;
}
node->add_input(TensorIdToString(fanin));
if (fanin.index() > Graph::kControlSlot) {
int node_input_size = node->input_size() - 1;
// If there are control dependencies in node, move newly inserted fanin to
// be before such control dependencies.
if (num_non_controlling_fanins < node_input_size) {
node->mutable_input()->SwapElements(node_input_size,
num_non_controlling_fanins);
}
}
return true;
}
bool MutableGraphView::AddFanin(absl::string_view node_name,
const TensorId& fanin) {
if (!IsTensorIdPortValid(fanin)) {
return false;
}
NodeDef* node = GetNode(node_name);
if (node == nullptr) {
return false;
}
return AddFanin(node, fanin);
}
bool MutableGraphView::RemoveFanins(NodeDef* node,
absl::Span<const TensorId> fanins) {
bool modified = false;
auto mutable_inputs = node->mutable_input();
int curr_pos = 0;
int num_inputs = node->input_size();
for (int i = 0; i < num_inputs; ++i) {
TensorId tensor_id = ParseTensorName(node->input(i));
bool remove_fanin =
std::find(fanins.begin(), fanins.end(), tensor_id) != fanins.end();
bool update_fanin = !remove_fanin && modified;
if (remove_fanin || update_fanin) {
OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index());
InputPort input;
input.node = node;
input.port_id =
tensor_id.index() == Graph::kControlSlot ? Graph::kControlSlot : i;
if (remove_fanin) {
fanouts()[fanin].erase(input);
} else {
// Shift inputs to be retained.
if (tensor_id.index() > Graph::kControlSlot) {
fanouts()[fanin].erase(input);
fanouts()[fanin].insert(InputPort(node, i));
}
mutable_inputs->SwapElements(i, curr_pos++);
}
modified = true;
} else {
// Skip inputs to be retained until first modification.
curr_pos++;
}
}
if (modified) {
mutable_inputs->DeleteSubrange(curr_pos, num_inputs - curr_pos);
}
return modified;
}
bool MutableGraphView::RemoveFanin(absl::string_view node_name,
const TensorId& fanin) {
if (!IsTensorIdPortValid(fanin)) {
return false;
}
NodeDef* node = GetNode(node_name);
if (node == nullptr) {
return false;
}
return RemoveFanins(node, {fanin});
}
bool MutableGraphView::RemoveAllFanins(absl::string_view node_name,
bool keep_controlling_fanins) {
NodeDef* node = GetNode(node_name);
if (node == nullptr || node->input().empty()) {
return false;
}
RemoveFaninsInternal(node, keep_controlling_fanins);
if (keep_controlling_fanins) {
int num_non_controlling_fanins =
NumFanins(*node, /*include_controlling_nodes=*/false);
if (num_non_controlling_fanins == 0) {
return false;
} else if (num_non_controlling_fanins < node->input_size()) {
node->mutable_input()->DeleteSubrange(0, num_non_controlling_fanins);
} else {
node->clear_input();
}
} else {
node->clear_input();
}
return true;
}
bool MutableGraphView::UpdateFanin(absl::string_view node_name,
const TensorId& from_fanin,
const TensorId& to_fanin) {
if (from_fanin == to_fanin || !IsTensorIdPortValid(from_fanin) ||
!IsTensorIdPortValid(to_fanin)) {
return false;
}
NodeDef* node = GetNode(node_name);
if (node == nullptr) {
return false;
}
bool is_from_fanin_control = from_fanin.index() == Graph::kControlSlot;
bool is_to_fanin_control = to_fanin.index() == Graph::kControlSlot;
// When replacing a non control dependency fanin with a control dependency, or
// vice versa, remove and add, so ports can be updated properly in fanout(s).
if (is_from_fanin_control || is_to_fanin_control) {
bool modified = RemoveFanins(node, {from_fanin});
if (!HasFanin(*node, to_fanin)) {
modified |= AddFanin(node, to_fanin);
}
return modified;
}
// In place mutation, requires no shifting of ports.
NodeDef* from_fanin_node = GetNode(from_fanin.node());
NodeDef* to_fanin_node = GetNode(to_fanin.node());
if (from_fanin_node == nullptr || to_fanin_node == nullptr) {
return false;
}
string to_fanin_string = TensorIdToString(to_fanin);
int num_inputs = node->input_size();
bool modified = false;
for (int i = 0; i < num_inputs; ++i) {
if (ParseTensorName(node->input(i)) == from_fanin) {
OutputPort from_fanin_port(from_fanin_node, from_fanin.index());
InputPort old_input;
old_input.node = node;
old_input.port_id =
from_fanin.index() == Graph::kControlSlot ? Graph::kControlSlot : i;
fanouts()[from_fanin_port].erase(old_input);
OutputPort to_fanin_port(to_fanin_node, to_fanin.index());
InputPort new_input;
new_input.node = node;
new_input.port_id =
to_fanin.index() == Graph::kControlSlot ? Graph::kControlSlot : i;
fanouts()[to_fanin_port].insert(new_input);
node->set_input(i, to_fanin_string);
modified = true;
}
}
return modified;
}
void MutableGraphView::DeleteNodes(const std::set<string>& nodes_to_delete) {
for (const string& node_name_to_delete : nodes_to_delete)
RemoveFanouts(nodes().at(node_name_to_delete));
RemoveFaninsInternal(nodes().at(node_name_to_delete),
/*keep_controlling_fanins=*/false);
for (const string& node_name_to_delete : nodes_to_delete)
nodes().erase(node_name_to_delete);
EraseNodesFromGraph(nodes_to_delete, graph());
}
void MutableGraphView::RemoveFanouts(NodeDef* deleted_node) {
void MutableGraphView::RemoveFaninsInternal(NodeDef* deleted_node,
bool keep_controlling_fanins) {
for (int i = 0; i < deleted_node->input_size(); ++i) {
TensorId tensor_id = ParseTensorName(deleted_node->input(i));
if (keep_controlling_fanins && tensor_id.index() < 0) {
break;
}
OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index());
InputPort input;

View File

@ -16,7 +16,17 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
#define TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
#include <set>
#include <string>
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
@ -60,6 +70,38 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
// 2. foo2(new_bar:1, other:1)
void UpdateFanouts(absl::string_view from_node, absl::string_view to_node);
// Add fanin to node `node_name`. If the node or fanin do not exist in the
// graph, nothing will be modified in the graph. If fanin is a control
// dependency, existing control dependencies will be checked first before
// adding. Otherwise fanin will be added after existing non control dependency
// inputs.
//
// This will return true iff the node is modified. If a control dependency
// already exists, the node will not be modified.
bool AddFanin(absl::string_view node_name, const TensorId& fanin);
// Remove fanin from node `node_name`. If the node or fanin do not exist in
// the graph, nothing will be modified in the graph. If there are multiple
// inputs that match the fanin, all of them will be removed.
//
// This will return true iff the node is modified. If no inputs match the
// fanin, the node will not be modified.
bool RemoveFanin(absl::string_view node_name, const TensorId& fanin);
// Remove all fanins from node `node_name`. Control dependencies will be
// retained if keep_controlling_fanins is true.
//
// This will return true iff the node is modified.
bool RemoveAllFanins(absl::string_view node_name,
bool keep_controlling_fanins);
// Replace all fanins `from_fanin` with `to_fanin` in node `node_name`. If
// the fanins or node do not exist, nothing will be modified in the graph.
//
// This will return true iff the node is modified.
bool UpdateFanin(absl::string_view node_name, const TensorId& from_fanin,
const TensorId& to_fanin);
// Deletes nodes from the graph.
void DeleteNodes(const std::set<string>& nodes_to_delete);
@ -79,9 +121,22 @@ class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
// behavior is undefined.
void UpdateFanouts(NodeDef* from_node, NodeDef* to_node);
// Remove fanouts of the deleted node from internal state (including control
// dependencies).
void RemoveFanouts(NodeDef* deleted_node);
// Remove fanins of the deleted node from internal state. Control dependencies
// are retained iff keep_controlling_fanins is true.
void RemoveFaninsInternal(NodeDef* deleted_node,
bool keep_controlling_fanins);
// Add fanin to node. If the node or fanin do not exist in the graph, nothing
// will be modified in the graph. If fanin is a control dependency, existing
// control dependencies will be checked first before adding. Otherwise fanin
// will be added after existing non control dependency inputs.
//
// This will return true iff the node is modified. If a control dependency
// already exists, the node will not be modified.
bool AddFanin(NodeDef* node, const TensorId& fanin);
// Remove any fanin in node that matches to a fanin in fanins.
bool RemoveFanins(NodeDef* node, absl::Span<const TensorId> fanins);
};
} // end namespace grappler

View File

@ -16,8 +16,10 @@ limitations under the License.
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -110,6 +112,356 @@ TEST(MutableGraphViewTest, AddAndUpdateFanoutsWithoutSelfLoops) {
EXPECT_EQ(new_bar_fanouts.count(MutableGraphView::InputPort(foo, -1)), 1);
}
GraphDef SimpleMutateFaninGraph() {
// Actual node.op() is not important in this test.
GraphDef graph_def = test::function::GDef(
{NDef("a", "NotImportant", {}, {}), NDef("b", "NotImportant", {}, {}),
NDef("c", "NotImportant", {}, {}), NDef("d", "NotImportant", {}, {}),
NDef("foo_1", "NotImportant", {"a"}),
NDef("foo_2", "NotImportant", {"b", "^a", "^c"}),
NDef("foo_3", "NotImportant", {"b", "a:1", "a:1"}),
NDef("foo_4", "NotImportant", {"a", "b:2", "b:2", "^c", "^d"}),
NDef("foo_5", "NotImportant", {}),
NDef("foo_6", "NotImportant", {"^a", "^b"})},
/*funcs=*/{});
return graph_def;
}
void CompareNodeInputs(const MutableGraphView& graph, const NodeDef* expected,
NodeDef* actual) {
ASSERT_EQ(actual->input_size(), expected->input_size());
int port;
for (int i = 0; i < actual->input_size(); ++i) {
EXPECT_EQ(actual->input(i), expected->input(i));
TensorId tensor_id = ParseTensorName(expected->input(i));
if (tensor_id.index() == Graph::kControlSlot) {
port = Graph::kControlSlot;
} else {
port = i;
}
MutableGraphView::InputPort input_port(actual, port);
MutableGraphView::OutputPort output_port =
graph.GetOutputPort(tensor_id.node(), tensor_id.index());
EXPECT_EQ(graph.GetFanin(input_port).contains(output_port), true);
EXPECT_EQ(graph.GetFanout(output_port).contains(input_port), true);
}
}
void TestAddFanin(absl::string_view node_name, const TensorId& fanin_to_add,
bool modified, const NodeDef* expected_node) {
GraphDef graph_def = SimpleMutateFaninGraph();
MutableGraphView graph(&graph_def);
auto node = graph.GetNode(node_name);
if (expected_node == nullptr) {
EXPECT_EQ(node, nullptr);
} else {
EXPECT_NE(node, nullptr);
}
EXPECT_EQ(modified, graph.AddFanin(node_name, fanin_to_add));
if (expected_node != nullptr) {
CompareNodeInputs(graph, expected_node, node);
}
}
TEST(MutableGraphViewTest, AddFanin) {
NodeDef expected_node;
// Add input to node with 1 input 0 controls.
expected_node = NDef("", "", {"a", "b:1"});
TestAddFanin("foo_1", {"b", 1}, /*modified=*/true, &expected_node);
// Add input to node with multiple inputs and 0 controls.
expected_node = NDef("", "", {"b", "a:1", "a:1", "b:2"});
TestAddFanin("foo_3", {"b", 2}, /*modified=*/true, &expected_node);
// Add input to node with 1 input multiple controls.
expected_node = NDef("", "", {"b", "a", "^c", "^a"});
TestAddFanin("foo_2", {"a", 0}, /*modified=*/true, &expected_node);
// Add input to node with multiple inputs and controls.
expected_node = NDef("", "", {"a", "b:2", "b:2", "a:1", "^d", "^c"});
TestAddFanin("foo_4", {"a", 1}, /*modified=*/true, &expected_node);
// Add input to node with 0 inputs 0 controls.
expected_node = NDef("", "", {"a:1"});
TestAddFanin("foo_5", {"a", 1}, /*modified=*/true, &expected_node);
// Add input to node with 0 inputs multiple controls.
expected_node = NDef("", "", {"c:1", "^b", "^a"});
TestAddFanin("foo_6", {"c", 1}, /*modified=*/true, &expected_node);
// Add control to node with 1 input 0 controls.
expected_node = NDef("", "", {"a", "^b"});
TestAddFanin("foo_1", {"b", Graph::kControlSlot}, /*modified=*/true,
&expected_node);
// Add control to node with multiple inputs and 0 controls.
expected_node = NDef("", "", {"b", "a:1", "a:1", "^c"});
TestAddFanin("foo_3", {"c", Graph::kControlSlot}, /*modified=*/true,
&expected_node);
// Add control to node with 1 input multiple controls.
expected_node = NDef("", "", {"b", "^a", "^c", "^d"});
TestAddFanin("foo_2", {"d", Graph::kControlSlot}, /*modified=*/true,
&expected_node);
// Add control to node with multiple input multiple controls.
expected_node = NDef("", "", {"a", "b:2", "b:2", "^c", "^d", "^a"});
TestAddFanin("foo_4", {"a", Graph::kControlSlot}, /*modified=*/true,
&expected_node);
// Add control to node with 0 inputs 0 controls.
expected_node = NDef("", "", {"^a"});
TestAddFanin("foo_5", {"a", Graph::kControlSlot}, /*modified=*/true,
&expected_node);
// Add control to node with 0 inputs multiple controls.
expected_node = NDef("", "", {"^a", "^b", "^c"});
TestAddFanin("foo_6", {"c", Graph::kControlSlot}, /*modified=*/true,
&expected_node);
// Add control to node with control that already exists.
expected_node = NDef("", "", {"b", "^a", "^c"});
TestAddFanin("foo_2", {"a", Graph::kControlSlot}, /*modified=*/false,
&expected_node);
// Add fanin to node where node is missing.
TestAddFanin("foo_missing", {"a", 0}, /*modified=*/false, nullptr);
// Add fanin to node where fanin is missing.
expected_node = NDef("", "", {"a"});
TestAddFanin("foo_1", {"bar_missing", 0}, /*modified=*/false, &expected_node);
// Add fanin to node where node and fanin are missing.
TestAddFanin("foo_missing", {"bar_missing", 0}, /*modified=*/false,
/*expected_node=*/nullptr);
}
void CheckFanout(const MutableGraphView& graph, const TensorId& fanin,
absl::string_view node_name) {
MutableGraphView::OutputPort output_port =
graph.GetOutputPort(fanin.node(), fanin.index());
auto fanouts = graph.GetFanout(output_port);
for (auto fanout : fanouts) {
EXPECT_NE(fanout.node->name(), fanin.node());
}
}
void TestRemoveFanin(absl::string_view node_name,
const TensorId& fanin_to_remove, bool modified,
const NodeDef* expected_node) {
GraphDef graph_def = SimpleMutateFaninGraph();
MutableGraphView graph(&graph_def);
auto node = graph.GetNode(node_name);
if (expected_node == nullptr) {
EXPECT_EQ(nullptr, node);
} else {
EXPECT_NE(nullptr, node);
}
EXPECT_EQ(modified, graph.RemoveFanin(node_name, fanin_to_remove));
if (expected_node != nullptr) {
CompareNodeInputs(graph, expected_node, node);
if (modified) {
CheckFanout(graph, fanin_to_remove, node_name);
}
}
}
TEST(MutableGraphViewTest, RemoveFanin) {
NodeDef expected_node;
// Remove input from node with 1 input 0 controls.
expected_node = NDef("", "", {});
TestRemoveFanin("foo_1", {"a", 0}, /*modified=*/true, &expected_node);
// Remove input from node with multiple inputs and 0 controls.
expected_node = NDef("", "", {"b"});
TestRemoveFanin("foo_3", {"a", 1}, /*modified=*/true, &expected_node);
// Remove input from node with 1 input multiple controls.
expected_node = NDef("", "", {"^a", "^c"});
TestRemoveFanin("foo_2", {"b", 0}, /*modified=*/true, &expected_node);
// Remove input from node with multiple inputs and controls.
expected_node = NDef("", "", {"a", "^c", "^d"});
TestRemoveFanin("foo_4", {"b", 2}, /*modified=*/true, &expected_node);
// Remove control from node with 1 input multiple controls.
expected_node = NDef("", "", {"b", "^c"});
TestRemoveFanin("foo_2", {"a", Graph::kControlSlot}, /*modified=*/true,
&expected_node);
// Remove control from node with multiple input multiple controls.
expected_node = NDef("", "", {"a", "b:2", "b:2", "^c"});
TestRemoveFanin("foo_4", {"d", Graph::kControlSlot}, /*modified=*/true,
&expected_node);
// Remove control from node with 0 inputs multiple controls.
expected_node = NDef("", "", {"^b"});
TestRemoveFanin("foo_6", {"a", Graph::kControlSlot}, /*modified=*/true,
&expected_node);
// Remove input from node with 0 inputs 0 controls.
expected_node = NDef("", "", {});
TestRemoveFanin("foo_5", {"a", 1}, /*modified=*/false, &expected_node);
// Remove input from node with 0 inputs multiple controls.
expected_node = NDef("", "", {"^a", "^b"});
TestRemoveFanin("foo_6", {"a", 1}, /*modified=*/false, &expected_node);
// Remove control from node with 1 input 0 controls.
expected_node = NDef("", "", {"a"});
TestRemoveFanin("foo_1", {"b", Graph::kControlSlot}, /*modified=*/false,
&expected_node);
// Remove control from node with multiple inputs and 0 controls.
expected_node = NDef("", "", {"b", "a:1", "a:1"});
TestRemoveFanin("foo_3", {"c", Graph::kControlSlot}, /*modified=*/false,
&expected_node);
// Remove control from node with 0 inputs 0 controls.
expected_node = NDef("", "", {});
TestRemoveFanin("foo_5", {"a", Graph::kControlSlot}, /*modified=*/false,
&expected_node);
// Remove fanin from node where node is missing.
TestRemoveFanin("foo_missing", {"a", 0}, /*modified=*/false,
/*expected_node=*/nullptr);
// Remove fanin from node where fanin is missing.
expected_node = NDef("", "", {"a"});
TestRemoveFanin("foo_1", {"bar_missing", 0}, /*modified=*/false,
&expected_node);
// Remove fanin from node where node and fanin are missing.
TestRemoveFanin("foo_missing", {"bar_missing", 0}, /*modified=*/false,
/*expected_node=*/nullptr);
}
void TestRemoveAllFanins(absl::string_view node_name,
bool keep_controlling_nodes, bool modified,
const NodeDef* expected_node) {
GraphDef graph_def = SimpleMutateFaninGraph();
MutableGraphView graph(&graph_def);
auto node = graph.GetNode(node_name);
absl::flat_hash_set<string> fanin_strings;
if (expected_node == nullptr) {
EXPECT_EQ(node, nullptr);
} else {
EXPECT_NE(node, nullptr);
fanin_strings.insert(node->input().begin(), node->input().end());
}
EXPECT_EQ(modified, graph.RemoveAllFanins(node_name, keep_controlling_nodes));
if (expected_node != nullptr) {
CompareNodeInputs(graph, expected_node, node);
if (modified) {
TensorId tensor_id;
auto retained_inputs = absl::flat_hash_set<string>(node->input().begin(),
node->input().end());
for (const string& fanin : fanin_strings) {
if (!retained_inputs.contains(fanin)) {
tensor_id = ParseTensorName(fanin);
CheckFanout(graph, tensor_id, node_name);
}
}
}
}
}
TEST(MutableGraphViewTest, RemoveAllFanins) {
NodeDef expected_node;
// Remove all fanins from node with no control dependencies.
expected_node = NDef("", "", {});
TestRemoveAllFanins("foo_3", /*keep_controlling_nodes=*/false,
/*modified=*/true, &expected_node);
// Remove all fanins from node with control dependencies.
TestRemoveAllFanins("foo_4", /*keep_controlling_nodes=*/false,
/*modified=*/true, &expected_node);
// Remove all fanins from node with no control dependencies and preserve
// control dependencies.
TestRemoveAllFanins("foo_3", /*keep_controlling_nodes=*/true,
/*modified=*/true, &expected_node);
// Remove all fanins from node with control dependencies and preserve control
// dependencies.
expected_node = NDef("", "", {"^c", "^d"});
TestRemoveAllFanins("foo_4", /*keep_controlling_nodes=*/true,
/*modified=*/true, &expected_node);
// Remove all fanins from node with no fanins.
expected_node = NDef("", "", {});
TestRemoveAllFanins("foo_5", /*keep_controlling_nodes=*/false,
/*modified=*/false, &expected_node);
TestRemoveAllFanins("foo_5", /*keep_controlling_nodes=*/true,
/*modified=*/false, &expected_node);
// Remove all fanins from node with only control dependencies.
TestRemoveAllFanins("foo_6", /*keep_controlling_nodes=*/false,
/*modified=*/true, &expected_node);
expected_node = NDef("", "", {"^a", "^b"});
TestRemoveAllFanins("foo_6", /*keep_controlling_nodes=*/true,
/*modified=*/false, &expected_node);
// Remove all fanins from node where node is missing.
TestRemoveAllFanins("foo_missing", /*keep_controlling_nodes=*/false,
/*modified=*/false, /*expected_node=*/nullptr);
TestRemoveAllFanins("foo_missing", /*keep_controlling_nodes=*/true,
/*modified=*/false, /*expected_node=*/nullptr);
}
void TestUpdateFanin(absl::string_view node_name, const TensorId& from_fanin,
const TensorId& to_fanin, bool modified,
const NodeDef* expected_node) {
GraphDef graph_def = SimpleMutateFaninGraph();
MutableGraphView graph(&graph_def);
auto node = graph.GetNode(node_name);
if (expected_node == nullptr) {
EXPECT_EQ(node, nullptr);
} else {
EXPECT_NE(node, nullptr);
}
EXPECT_EQ(modified, graph.UpdateFanin(node_name, from_fanin, to_fanin));
if (expected_node != nullptr) {
CompareNodeInputs(graph, expected_node, node);
if (modified) {
CheckFanout(graph, from_fanin, node_name);
}
}
}
TEST(MutableGraphViewTest, UpdateFanin) {
NodeDef expected_node;
// Update fanin from non control to non control.
expected_node = NDef("", "", {"a", "b:3", "b:3", "^c", "^d"});
TestUpdateFanin("foo_4", {"b", 2}, {"b", 3}, /*modified=*/true,
&expected_node);
// Update fanin from non control to control.
expected_node = NDef("", "", {"a", "^c", "^d", "^b"});
TestUpdateFanin("foo_4", {"b", 2}, {"b", Graph::kControlSlot},
/*modified=*/true, &expected_node);
// Update fanin from control to non control.
expected_node = NDef("", "", {"a", "b:2", "b:2", "d:1", "^c"});
TestUpdateFanin("foo_4", {"d", Graph::kControlSlot}, {"d", 1},
/*modified=*/true, &expected_node);
// Update fanin from control to control.
expected_node = NDef("", "", {"a", "b:2", "b:2", "^d", "^b"});
TestUpdateFanin("foo_4", {"c", Graph::kControlSlot},
{"b", Graph::kControlSlot}, /*modified=*/true,
&expected_node);
// Update fanin from control to existing control.
expected_node = NDef("", "", {"a", "b:2", "b:2", "^d"});
TestUpdateFanin("foo_4", {"c", Graph::kControlSlot},
{"d", Graph::kControlSlot}, /*modified=*/true,
&expected_node);
// Update fanin of node where from and to fanins are the same.
expected_node = NDef("", "", {"a"});
TestUpdateFanin("foo_1", {"a", -1}, {"a", -1}, /*modified=*/false,
&expected_node);
TestUpdateFanin("foo_1", {"a", 0}, {"a", 0}, /*modified=*/false,
&expected_node);
TestUpdateFanin("foo_1", {"a", 1}, {"a", 1}, /*modified=*/false,
&expected_node);
// Update fanin of node where node is missing.
TestUpdateFanin("foo_missing", {"a", 0}, {"a", 1}, /*modified=*/false,
/*expected_node=*/nullptr);
// Update fanin of node where from fanin is missing.
TestUpdateFanin("foo_1", {"from_bar_missing", 0}, {"a", 1},
/*modified=*/false, &expected_node);
// Update fanin of node where to fanin is missing.
TestUpdateFanin("foo_1", {"a", 0}, {"to_bar_missing", 1}, /*modified=*/false,
&expected_node);
// Update fanin of node where from/to fanins and node are missing.
TestUpdateFanin("foo_missing", {"from_bar_missing", 0}, {"to_bar_missing", 1},
/*modified=*/false, /*expected_node=*/nullptr);
}
TEST(MutableGraphViewTest, DeleteNodes) {
// Actual node.op() is not important in this test.
GraphDef graph_def = test::function::GDef(